refactor: move remaining api packages to core (#1731)

* core 1

* api/openai/files fix

* core 2 - core/config

* move over core api.go and tests to the start of core/http

* move over localai specific endpoints to core/http, begin the service/endpoint split there

* refactor big chunk on the plane

* refactor chunk 2 on plane, next step: port and modify changes to request.go

* easy fixes for request.go, major changes not done yet

* lintfix

* json tag lintfix?

* gitignore and .keep files

* strange fix attempt: rename the config dir?
This commit is contained in:
Dave 2024-03-01 10:19:53 -05:00 committed by GitHub
parent 316de82f51
commit 1c312685aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 1440 additions and 1206 deletions

4
.gitignore vendored
View File

@ -21,6 +21,7 @@ local-ai
!charts/*
# prevent above rules from omitting the api/localai folder
!api/localai
!core/**/localai
# Ignore models
models/*
@ -34,6 +35,7 @@ release/
.idea
# Generated during build
backend-assets/
backend-assets/*
!backend-assets/.keep
prepare
/ggml-metal.metal

View File

@ -44,6 +44,8 @@ BUILD_ID?=git
TEST_DIR=/tmp/test
TEST_FLAKES?=5
RANDOM := $(shell bash -c 'echo $$RANDOM')
VERSION?=$(shell git describe --always --tags || echo "dev" )
@ -337,7 +339,7 @@ test: prepare test-models/testmodel grpcs
export GO_TAGS="tts stablediffusion"
$(MAKE) prepare-test
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
$(MAKE) test-gpt4all
$(MAKE) test-llama
$(MAKE) test-llama-gguf

View File

@ -1,162 +0,0 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
return bm.options.Loader.ShutdownModel(backendId)
}
}

View File

@ -1,326 +0,0 @@
package localai
import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func NewGalleryService(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cm.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoint Service
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

0
configuration/.keep Normal file
View File

View File

@ -3,36 +3,36 @@ package backend
import (
"fmt"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
if !c.Embeddings {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
if !backendConfig.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
modelFile := c.Model
modelFile := backendConfig.Model
grpcOpts := gRPCModelOpts(c)
grpcOpts := gRPCModelOpts(backendConfig)
var inferenceModel interface{}
var err error
opts := modelOpts(c, o, []model.Option{
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithContext(appConfig.Context),
})
if c.Backend == "" {
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(c.Backend))
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}
@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}
@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.Embeddings = s
res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}

View File

@ -1,33 +1,33 @@
package backend
import (
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
CUDA: backendConfig.CUDA || backendConfig.Diffusers.CUDA,
SchedulerType: backendConfig.Diffusers.SchedulerType,
PipelineType: backendConfig.Diffusers.PipelineType,
CFGScale: backendConfig.Diffusers.CFGScale,
LoraAdapter: backendConfig.LoraAdapter,
LoraScale: backendConfig.LoraScale,
LoraBase: backendConfig.LoraBase,
IMG2IMG: backendConfig.Diffusers.IMG2IMG,
CLIPModel: backendConfig.Diffusers.ClipModel,
CLIPSubfolder: backendConfig.Diffusers.ClipSubFolder,
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
ControlNet: backendConfig.Diffusers.ControlNet,
}),
})
@ -40,19 +40,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
appConfig.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
EnableParameters: backendConfig.Diffusers.EnableParameters,
})
return err
}

View File

@ -8,8 +8,8 @@ import (
"sync"
"unicode/utf8"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -26,7 +26,7 @@ type TokenUsage struct {
Completion int
}
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
@ -140,7 +140,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.Config, input, prediction string) string {
func Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}

View File

@ -4,19 +4,17 @@ import (
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/core/config"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
)
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
if o.SingleBackend {
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
if o.ParallelBackendRequests {
if so.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
}
@ -28,14 +26,14 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range o.ExternalGRPCBackends {
for k, v := range so.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
return opts
}
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
@ -84,7 +82,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
}
}
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)

View File

@ -4,25 +4,24 @@ import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
opts := modelOpts(c, o, []model.Option{
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})
whisperModel, err := o.Loader.BackendLoader(opts...)
whisperModel, err := ml.BackendLoader(opts...)
if err != nil {
return nil, err
}
@ -34,6 +33,6 @@ func ModelTranscription(audio, language string, loader *model.ModelLoader, c con
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
Threads: uint32(backendConfig.Threads),
})
}

View File

@ -6,8 +6,8 @@ import (
"os"
"path/filepath"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
@ -29,22 +29,22 @@ func generateUniqueFileName(dir, baseName, ext string) string {
}
}
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) {
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(c)
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.Config{}, o, []model.Option{
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
piperModel, err := o.Loader.BackendLoader(opts...)
piperModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
@ -53,19 +53,19 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
modelPath = filepath.Join(loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil {
return "", nil, err
}
} else {

View File

@ -1,4 +1,4 @@
package options
package config
import (
"context"
@ -6,16 +6,14 @@ import (
"encoding/json"
"time"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type Option struct {
type ApplicationConfig struct {
Context context.Context
ConfigFile string
Loader *model.ModelLoader
ModelPath string
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug, DisableMessage bool
@ -27,7 +25,6 @@ type Option struct {
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
Metrics *metrics.Metrics
ModelLibraryURL string
@ -52,10 +49,10 @@ type Option struct {
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}
type AppOption func(*Option)
type AppOption func(*ApplicationConfig)
func NewOptions(o ...AppOption) *Option {
opt := &Option{
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
opt := &ApplicationConfig{
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
@ -70,63 +67,69 @@ func NewOptions(o ...AppOption) *Option {
}
func WithModelsURL(urls ...string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ModelsURL = urls
}
}
func WithModelPath(path string) AppOption {
return func(o *ApplicationConfig) {
o.ModelPath = path
}
}
func WithCors(b bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.CORS = b
}
}
func WithModelLibraryURL(url string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ModelLibraryURL = url
}
}
var EnableWatchDog = func(o *Option) {
var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}
var EnableWatchDogIdleCheck = func(o *Option) {
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogIdle = true
}
var EnableWatchDogBusyCheck = func(o *Option) {
var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogBusy = true
}
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.WatchDogBusyTimeout = t
}
}
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.WatchDogIdleTimeout = t
}
}
var EnableSingleBackend = func(o *Option) {
var EnableSingleBackend = func(o *ApplicationConfig) {
o.SingleBackend = true
}
var EnableParallelBackendRequests = func(o *Option) {
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
o.ParallelBackendRequests = true
}
var EnableGalleriesAutoload = func(o *Option) {
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
o.AutoloadGalleries = true
}
func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string)
}
@ -135,25 +138,25 @@ func WithExternalBackend(name string, uri string) AppOption {
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.CORSAllowOrigins = b
}
}
func WithBackendAssetsOutput(out string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.AssetsDestination = out
}
}
func WithBackendAssets(f embed.FS) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.BackendAssets = f
}
}
func WithStringGalleries(galls string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{}
@ -168,102 +171,96 @@ func WithStringGalleries(galls string) AppOption {
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Context = ctx
}
}
func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.PreloadModelsFromPath = configFile
}
}
func WithJSONStringPreload(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.PreloadJSONModels = configFile
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ConfigFile = configFile
}
}
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.Loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.UploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ContextSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.F16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.DisableMessage = disableMessage
}
}
func WithAudioDir(audioDir string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.AudioDir = audioDir
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ImageDir = imageDir
}
}
func WithUploadDir(uploadDir string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.UploadDir = uploadDir
}
}
func WithApiKeys(apiKeys []string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ApiKeys = apiKeys
}
}
func WithMetrics(meter *metrics.Metrics) AppOption {
return func(o *Option) {
o.Metrics = meter
}
}
// func WithMetrics(meter *metrics.Metrics) AppOption {
// return func(o *StartupOptions) {
// o.Metrics = meter
// }
// }

View File

@ -9,15 +9,16 @@ import (
"strings"
"sync"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type Config struct {
PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
type BackendConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
@ -159,37 +160,55 @@ type TemplateConfig struct {
Functions string `yaml:"function"`
}
type ConfigLoader struct {
configs map[string]Config
sync.Mutex
}
func (c *Config) SetFunctionCallString(s string) {
func (c *BackendConfig) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *Config) SetFunctionCallNameString(s string) {
func (c *BackendConfig) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *Config) ShouldUseFunctions() bool {
func (c *BackendConfig) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *Config) ShouldCallSpecificFunction() bool {
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
func (c *Config) FunctionToCall() string {
func (c *BackendConfig) FunctionToCall() string {
return c.functionCallNameString
}
func defaultPredictOptions(modelFile string) schema.PredictionOptions {
return schema.PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func DefaultConfig(modelFile string) *BackendConfig {
return &BackendConfig{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
////// Config Loader ////////
type BackendConfigLoader struct {
configs map[string]BackendConfig
sync.Mutex
}
// Load a config file for a model
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) {
func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigLoader, debug bool, threads, ctx int, f16 bool) (*BackendConfig, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(modelPath, modelName+".yaml")
var cfg *Config
var cfg *BackendConfig
defaults := func() {
cfg = DefaultConfig(modelName)
@ -199,13 +218,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct
cfg.Debug = debug
}
cfgExisting, exists := cm.GetConfig(modelName)
cfgExisting, exists := cl.GetBackendConfig(modelName)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
if err := cl.LoadBackendConfig(modelConfig); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelName)
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
@ -238,29 +257,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct
return cfg, nil
}
func defaultPredictOptions(modelFile string) PredictionOptions {
return PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func DefaultConfig(modelFile string) *Config {
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
func ReadBackendConfigFile(file string) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
@ -272,8 +275,8 @@ func ReadConfigFile(file string) ([]*Config, error) {
return *c, nil
}
func ReadConfig(file string) (*Config, error) {
c := &Config{}
func ReadBackendConfig(file string) (*BackendConfig, error) {
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
@ -285,10 +288,10 @@ func ReadConfig(file string) (*Config, error) {
return c, nil
}
func (cm *ConfigLoader) LoadConfigFile(file string) error {
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
c, err := ReadBackendConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
@ -299,49 +302,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error {
return nil
}
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
func (cl *BackendConfigLoader) LoadBackendConfig(file string) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm.configs[c.Name]