From 1c312685aa52333d6250f70cf03a2d4ee72c4509 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 1 Mar 2024 10:19:53 -0500 Subject: [PATCH] refactor: move remaining api packages to core (#1731) * core 1 * api/openai/files fix * core 2 - core/config * move over core api.go and tests to the start of core/http * move over localai specific endpoints to core/http, begin the service/endpoint split there * refactor big chunk on the plane * refactor chunk 2 on plane, next step: port and modify changes to request.go * easy fixes for request.go, major changes not done yet * lintfix * json tag lintfix? * gitignore and .keep files * strange fix attempt: rename the config dir? --- .gitignore | 4 +- Makefile | 4 +- api/localai/backend_monitor.go | 162 --------- api/localai/gallery.go | 326 ------------------ configuration/.keep | 0 core/backend/embeddings.go | 30 +- core/backend/image.go | 48 +-- core/backend/llm.go | 8 +- core/backend/options.go | 16 +- core/backend/transcript.go | 19 +- core/backend/tts.go | 26 +- .../application_config.go} | 99 +++--- core/config/{config.go => backend_config.go} | 153 ++++---- core/config/config_test.go | 26 +- core/http/api.go | 229 ++++-------- core/http/api_test.go | 146 +++++--- {api => core/http}/ctx/fiber.go | 0 .../http/endpoints/localai/backend_monitor.go | 36 ++ core/http/endpoints/localai/gallery.go | 146 ++++++++ core/http/endpoints/localai/metrics.go | 43 +++ .../http/endpoints/localai/tts.go | 25 +- {api => core/http/endpoints}/openai/chat.go | 37 +- .../http/endpoints}/openai/completion.go | 24 +- {api => core/http/endpoints}/openai/edit.go | 16 +- .../http/endpoints}/openai/embeddings.go | 15 +- {api => core/http/endpoints}/openai/files.go | 28 +- .../http/endpoints}/openai/files_test.go | 24 +- {api => core/http/endpoints}/openai/image.go | 22 +- .../http/endpoints}/openai/inference.go | 8 +- {api => core/http/endpoints}/openai/list.go | 8 +- .../http/endpoints}/openai/request.go | 27 +- .../http/endpoints}/openai/transcription.go | 12 +- core/schema/localai.go | 21 ++ core/schema/openai.go | 8 +- core/{config => schema}/prediction.go | 2 +- core/services/backend_monitor.go | 140 ++++++++ core/services/gallery.go | 167 +++++++++ core/services/metrics.go | 54 +++ core/startup/config_file_watcher.go | 100 ++++++ core/startup/startup.go | 128 +++++++ .../backend monitor/backend monitor.bru | 8 +- .../langchainjs-localai-example/src/index.mts | 4 +- go.mod | 6 +- go.sum | 12 +- main.go | 121 ++++--- metrics/metrics.go | 83 ----- pkg/downloader/uri.go | 4 + pkg/gallery/models_test.go | 1 - pkg/gallery/op.go | 18 + tests/integration/reflect_test.go | 2 +- 50 files changed, 1440 insertions(+), 1206 deletions(-) delete mode 100644 api/localai/backend_monitor.go delete mode 100644 api/localai/gallery.go create mode 100644 configuration/.keep rename core/{options/options.go => config/application_config.go} (69%) rename core/config/{config.go => backend_config.go} (77%) rename {api => core/http}/ctx/fiber.go (100%) create mode 100644 core/http/endpoints/localai/backend_monitor.go create mode 100644 core/http/endpoints/localai/gallery.go create mode 100644 core/http/endpoints/localai/metrics.go rename api/localai/localai.go => core/http/endpoints/localai/tts.go (56%) rename {api => core/http/endpoints}/openai/chat.go (90%) rename {api => core/http/endpoints}/openai/completion.go (82%) rename {api => core/http/endpoints}/openai/edit.go (77%) rename {api => core/http/endpoints}/openai/embeddings.go (73%) rename {api => core/http/endpoints}/openai/files.go (83%) rename 
{api => core/http/endpoints}/openai/files_test.go (92%) rename {api => core/http/endpoints}/openai/image.go (87%) rename {api => core/http/endpoints}/openai/inference.go (90%) rename {api => core/http/endpoints}/openai/list.go (87%) rename {api => core/http/endpoints}/openai/request.go (89%) rename {api => core/http/endpoints}/openai/transcription.go (71%) create mode 100644 core/schema/localai.go rename core/{config => schema}/prediction.go (99%) create mode 100644 core/services/backend_monitor.go create mode 100644 core/services/gallery.go create mode 100644 core/services/metrics.go create mode 100644 core/startup/config_file_watcher.go create mode 100644 core/startup/startup.go delete mode 100644 metrics/metrics.go create mode 100644 pkg/gallery/op.go diff --git a/.gitignore b/.gitignore index df00829c..b48f7391 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ local-ai !charts/* # prevent above rules from omitting the api/localai folder !api/localai +!core/**/localai # Ignore models models/* @@ -34,6 +35,7 @@ release/ .idea # Generated during build -backend-assets/ +backend-assets/* +!backend-assets/.keep prepare /ggml-metal.metal diff --git a/Makefile b/Makefile index e9d3b2bc..a52774cd 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,8 @@ BUILD_ID?=git TEST_DIR=/tmp/test +TEST_FLAKES?=5 + RANDOM := $(shell bash -c 'echo $$RANDOM') VERSION?=$(shell git describe --always --tags || echo "dev" ) @@ -337,7 +339,7 @@ test: prepare test-models/testmodel grpcs export GO_TAGS="tts stablediffusion" $(MAKE) prepare-test HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS) $(MAKE) test-gpt4all $(MAKE) test-llama $(MAKE) test-llama-gguf diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go deleted file mode 100644 index e6f1b409..00000000 --- a/api/localai/backend_monitor.go +++ /dev/null @@ -1,162 +0,0 @@ -package localai - -import ( - "context" - "fmt" - "strings" - - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - - "github.com/go-skynet/LocalAI/core/options" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - - gopsutil "github.com/shirou/gopsutil/v3/process" -) - -type BackendMonitorRequest struct { - Model string `json:"model" yaml:"model"` -} - -type BackendMonitorResponse struct { - MemoryInfo *gopsutil.MemoryInfoStat - MemoryPercent float32 - CPUPercent float64 -} - -type BackendMonitor struct { - configLoader *config.ConfigLoader - options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. 
-} - -func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor { - return BackendMonitor{ - configLoader: configLoader, - options: options, - } -} - -func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) { - config, exists := bm.configLoader.GetConfig(model) - var backend string - if exists { - backend = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backend = model - } - - if !strings.HasSuffix(backend, ".bin") { - backend = fmt.Sprintf("%s.bin", backend) - } - - pid, err := bm.options.Loader.GetGRPCPID(backend) - - if err != nil { - log.Error().Msgf("model %s : failed to find pid %+v", model, err) - return nil, err - } - - // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. - backendProcess, err := gopsutil.NewProcess(int32(pid)) - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) - return nil, err - } - - memInfo, err := backendProcess.MemoryInfo() - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) - return nil, err - } - - memPercent, err := backendProcess.MemoryPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) - return nil, err - } - - cpuPercent, err := backendProcess.CPUPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) - return nil, err - } - - return &BackendMonitorResponse{ - MemoryInfo: memInfo, - MemoryPercent: memPercent, - CPUPercent: cpuPercent, - }, nil -} - -func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) { - input := new(BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", err - } - - config, exists := bm.configLoader.GetConfig(input.Model) - var backendId string - if exists { - backendId = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. 
- backendId = input.Model - } - - if !strings.HasSuffix(backendId, ".bin") { - backendId = fmt.Sprintf("%s.bin", backendId) - } - - return backendId, nil -} - -func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - backendId, err := bm.getModelLoaderIDFromCtx(c) - if err != nil { - return err - } - - model := bm.options.Loader.CheckIsLoaded(backendId) - if model == "" { - return fmt.Errorf("backend %s is not currently loaded", backendId) - } - - status, rpcErr := model.GRPC(false, nil).Status(context.TODO()) - if rpcErr != nil { - log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bm.SampleLocalBackendProcess(backendId) - if slbErr != nil { - return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) - } - return c.JSON(proto.StatusResponse{ - State: proto.StatusResponse_ERROR, - Memory: &proto.MemoryUsageData{ - Total: val.MemoryInfo.VMS, - Breakdown: map[string]uint64{ - "gopsutil-RSS": val.MemoryInfo.RSS, - }, - }, - }) - } - - return c.JSON(status) - } -} - -func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - backendId, err := bm.getModelLoaderIDFromCtx(c) - if err != nil { - return err - } - - return bm.options.Loader.ShutdownModel(backendId) - } -} diff --git a/api/localai/gallery.go b/api/localai/gallery.go deleted file mode 100644 index ee6f4d7d..00000000 --- a/api/localai/gallery.go +++ /dev/null @@ -1,326 +0,0 @@ -package localai - -import ( - "context" - "fmt" - "os" - "slices" - "strings" - "sync" - - json "github.com/json-iterator/go" - "gopkg.in/yaml.v3" - - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/utils" - - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -type galleryOp struct { - req gallery.GalleryModel - id string - galleries []gallery.Gallery - galleryName string -} - -type galleryOpStatus struct { - FileName string `json:"file_name"` - Error error `json:"error"` - Processed bool `json:"processed"` - Message string `json:"message"` - Progress float64 `json:"progress"` - TotalFileSize string `json:"file_size"` - DownloadedFileSize string `json:"downloaded_size"` -} - -type galleryApplier struct { - modelPath string - sync.Mutex - C chan galleryOp - statuses map[string]*galleryOpStatus -} - -func NewGalleryService(modelPath string) *galleryApplier { - return &galleryApplier{ - modelPath: modelPath, - C: make(chan galleryOp), - statuses: make(map[string]*galleryOpStatus), - } -} - -func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetGalleryConfigFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) 
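The backend-monitor code deleted above keeps the same sampling approach after its move to core/services: look up an already-running process by PID with gopsutil/v3 and read its memory and CPU counters. A minimal sketch of that pattern, with the PID lookup elided and names illustrative:

package sketch

import (
	"fmt"

	gopsutil "github.com/shirou/gopsutil/v3/process"
)

// sampleBackend inspects an existing process by PID; despite the
// constructor-like name, gopsutil.NewProcess does not spawn anything.
func sampleBackend(pid int32) error {
	p, err := gopsutil.NewProcess(pid)
	if err != nil {
		return err
	}
	memInfo, err := p.MemoryInfo() // RSS/VMS, in bytes
	if err != nil {
		return err
	}
	memPercent, err := p.MemoryPercent()
	if err != nil {
		return err
	}
	cpuPercent, err := p.CPUPercent()
	if err != nil {
		return err
	}
	fmt.Printf("rss=%d vms=%d mem=%.2f%% cpu=%.2f%%\n",
		memInfo.RSS, memInfo.VMS, memPercent, cpuPercent)
	return nil
}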
- - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) -} - -func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) { - g.Lock() - defer g.Unlock() - g.statuses[s] = op -} - -func (g *galleryApplier) getStatus(s string) *galleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses[s] -} - -func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses -} - -func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { - go func() { - for { - select { - case <-c.Done(): - return - case op := <-g.C: - utils.ResetDownloadTimers() - - g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) - - // updates the status with an error - updateError := func(e error) { - g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) - } - - // displayDownload displays the download progress - progressCallback := func(fileName string, current string, total string, percentage float64) { - g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - utils.DisplayDownloadFunction(fileName, current, total, percentage) - } - - var err error - // if the request contains a gallery name, we apply the gallery from the gallery list - if op.galleryName != "" { - if strings.Contains(op.galleryName, "@") { - err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) - } else { - err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) - } - } else { - err = prepareModel(g.modelPath, op.req, cm, progressCallback) - } - - if err != nil { - updateError(err) - continue - } - - // Reload models - err = cm.LoadConfigs(g.modelPath) - if err != nil { - updateError(err) - continue - } - - err = cm.Preload(g.modelPath) - if err != nil { - updateError(err) - continue - } - - g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) - } - } - }() -} - -type galleryModel struct { - gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 - ID string `json:"id"` -} - -func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { - var err error - for _, r := range requests { - utils.ResetDownloadTimers() - if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) - } else { - if strings.Contains(r.ID, "@") { - err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } else { - err = gallery.InstallModelFromGalleryByName( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } - } - } - return err -} - -func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { - dat, err := os.ReadFile(s) - if err != nil { - return err - } - var requests []galleryModel - - if err := yaml.Unmarshal(dat, &requests); err != nil { - return err - } - - return processRequests(modelPath, s, cm, galleries, requests) -} - -func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { - var requests []galleryModel - err := json.Unmarshal([]byte(s), &requests) - if err != nil { - return err - } - - return 
processRequests(modelPath, s, cm, galleries, requests) -} - -/// Endpoint Service - -type ModelGalleryService struct { - galleries []gallery.Gallery - modelPath string - galleryApplier *galleryApplier -} - -type GalleryModel struct { - ID string `json:"id"` - gallery.GalleryModel -} - -func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService { - return ModelGalleryService{ - galleries: galleries, - modelPath: modelPath, - galleryApplier: galleryApplier, - } -} - -func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - status := mgs.galleryApplier.getStatus(c.Params("uuid")) - if status == nil { - return fmt.Errorf("could not find any status for ID") - } - return c.JSON(status) - } -} - -func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - return c.JSON(mgs.galleryApplier.getAllStatus()) - } -} - -func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(GalleryModel) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - uuid, err := uuid.NewUUID() - if err != nil { - return err - } - mgs.galleryApplier.C <- galleryOp{ - req: input.GalleryModel, - id: uuid.String(), - galleryName: input.ID, - galleries: mgs.galleries, - } - return c.JSON(struct { - ID string `json:"uuid"` - StatusURL string `json:"status"` - }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) - } -} - -func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) - - models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) - if err != nil { - return err - } - log.Debug().Msgf("Models found from galleries: %+v", models) - for _, m := range models { - log.Debug().Msgf("Model found from galleries: %+v", m) - } - dat, err := json.Marshal(models) - if err != nil { - return err - } - return c.Send(dat) - } -} - -// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! 
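Stripped of the download plumbing, the galleryApplier above is a single worker goroutine draining a channel of operations and publishing per-ID status under a mutex. A reduced sketch of that shape, with illustrative type and field names rather than the exact LocalAI ones:

package sketch

import (
	"context"
	"sync"
)

type applyOp struct{ id string }

// applier mirrors the shape of galleryApplier: one channel of work, one
// mutex-guarded status map, one worker goroutine.
type applier struct {
	sync.Mutex
	C        chan applyOp
	statuses map[string]string
}

func (a *applier) Start(ctx context.Context) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case op := <-a.C:
				a.setStatus(op.id, "processing")
				// ... install the model here, then report the outcome:
				a.setStatus(op.id, "completed")
			}
		}
	}()
}

func (a *applier) setStatus(id, s string) {
	a.Lock()
	defer a.Unlock()
	a.statuses[id] = s
}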
-func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - return c.Send(dat) - } -} - -func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s already exists", input.Name) - } - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - log.Debug().Msgf("Adding %+v to gallery list", *input) - mgs.galleries = append(mgs.galleries, *input) - return c.Send(dat) - } -} - -func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s is not currently registered", input.Name) - } - mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) - return c.Send(nil) - } -} diff --git a/configuration/.keep b/configuration/.keep new file mode 100644 index 00000000..e69de29b diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index d8b89e12..0a74ea4c 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -3,36 +3,36 @@ package backend import ( "fmt" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { - if !c.Embeddings { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { + if !backendConfig.Embeddings { return nil, fmt.Errorf("endpoint disabled for this model by API configuration") } - modelFile := c.Model + modelFile := backendConfig.Model - grpcOpts := gRPCModelOpts(c) + grpcOpts := gRPCModelOpts(backendConfig) var inferenceModel interface{} var err error - opts := modelOpts(c, o, []model.Option{ + opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithAssetDir(appConfig.AssetsDestination), model.WithModel(modelFile), - model.WithContext(o.Context), + model.WithContext(appConfig.Context), }) - if c.Backend == "" { + if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { - opts = append(opts, model.WithBackendString(c.Backend)) + opts = append(opts, model.WithBackendString(backendConfig.Backend)) inferenceModel, err = loader.BackendLoader(opts...) 
} if err != nil { @@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. switch model := inferenceModel.(type) { case grpc.Backend: fn = func() ([]float32, error) { - predictOptions := gRPCPredictOpts(c, loader.ModelPath) + predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) if len(tokens) > 0 { embeds := []int32{} @@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.EmbeddingTokens = embeds - res, err := model.Embeddings(o.Context, predictOptions) + res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.Embeddings = s - res, err := model.Embeddings(o.Context, predictOptions) + res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } diff --git a/core/backend/image.go b/core/backend/image.go index 12ea57ce..60db48f9 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -1,33 +1,33 @@ package backend import ( - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" ) -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(c.Backend), - model.WithAssetDir(o.AssetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithContext(o.Context), - model.WithModel(c.Model), + opts := modelOpts(backendConfig, appConfig, []model.Option{ + model.WithBackendString(backendConfig.Backend), + model.WithAssetDir(appConfig.AssetsDestination), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithContext(appConfig.Context), + model.WithModel(backendConfig.Model), model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ - CUDA: c.CUDA || c.Diffusers.CUDA, - SchedulerType: c.Diffusers.SchedulerType, - PipelineType: c.Diffusers.PipelineType, - CFGScale: c.Diffusers.CFGScale, - LoraAdapter: c.LoraAdapter, - LoraScale: c.LoraScale, - LoraBase: c.LoraBase, - IMG2IMG: c.Diffusers.IMG2IMG, - CLIPModel: c.Diffusers.ClipModel, - CLIPSubfolder: c.Diffusers.ClipSubFolder, - CLIPSkip: int32(c.Diffusers.ClipSkip), - ControlNet: c.Diffusers.ControlNet, + CUDA: backendConfig.CUDA || backendConfig.Diffusers.CUDA, + SchedulerType: backendConfig.Diffusers.SchedulerType, + PipelineType: backendConfig.Diffusers.PipelineType, + CFGScale: backendConfig.Diffusers.CFGScale, + LoraAdapter: backendConfig.LoraAdapter, + LoraScale: backendConfig.LoraScale, + LoraBase: backendConfig.LoraBase, + IMG2IMG: backendConfig.Diffusers.IMG2IMG, + CLIPModel: backendConfig.Diffusers.ClipModel, + CLIPSubfolder: backendConfig.Diffusers.ClipSubFolder, + CLIPSkip: int32(backendConfig.Diffusers.ClipSkip), + ControlNet: backendConfig.Diffusers.ControlNet, }), }) @@ -40,19 +40,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat fn := func() error { _, err := 
inferenceModel.GenerateImage( - o.Context, + appConfig.Context, &proto.GenerateImageRequest{ Height: int32(height), Width: int32(width), Mode: int32(mode), Step: int32(step), Seed: int32(seed), - CLIPSkip: int32(c.Diffusers.ClipSkip), + CLIPSkip: int32(backendConfig.Diffusers.ClipSkip), PositivePrompt: positive_prompt, NegativePrompt: negative_prompt, Dst: dst, Src: src, - EnableParameters: c.Diffusers.EnableParameters, + EnableParameters: backendConfig.Diffusers.EnableParameters, }) return err } diff --git a/core/backend/llm.go b/core/backend/llm.go index d1081ad6..f16878c0 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -8,8 +8,8 @@ import ( "sync" "unicode/utf8" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" @@ -26,7 +26,7 @@ type TokenUsage struct { Completion int } -func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model grpcOpts := gRPCModelOpts(c) @@ -140,7 +140,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) var mu sync.Mutex = sync.Mutex{} -func Finetune(config config.Config, input, prediction string) string { +func Finetune(config config.BackendConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } diff --git a/core/backend/options.go b/core/backend/options.go index 9710ac17..60160572 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -4,19 +4,17 @@ import ( "os" "path/filepath" + "github.com/go-skynet/LocalAI/core/config" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" - - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" ) -func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { - if o.SingleBackend { +func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { + if so.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } - if o.ParallelBackendRequests { + if so.ParallelBackendRequests { opts = append(opts, model.EnableParallelRequests) } @@ -28,14 +26,14 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model. 
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) } - for k, v := range o.ExternalGRPCBackends { + for k, v := range so.ExternalGRPCBackends { opts = append(opts, model.WithExternalBackend(k, v)) } return opts } -func gRPCModelOpts(c config.Config) *pb.ModelOptions { +func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -84,7 +82,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions { } } -func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { p := filepath.Join(modelPath, c.PromptCachePath) diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 1cbaf820..bbb4f4b4 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -4,25 +4,24 @@ import ( "context" "fmt" - config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) { +func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { - opts := modelOpts(c, o, []model.Option{ + opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), - model.WithModel(c.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), + model.WithModel(backendConfig.Model), + model.WithContext(appConfig.Context), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithAssetDir(appConfig.AssetsDestination), }) - whisperModel, err := o.Loader.BackendLoader(opts...) + whisperModel, err := ml.BackendLoader(opts...) 
if err != nil { return nil, err } @@ -34,6 +33,6 @@ func ModelTranscription(audio, language string, loader *model.ModelLoader, c con return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, - Threads: uint32(c.Threads), + Threads: uint32(backendConfig.Threads), }) } diff --git a/core/backend/tts.go b/core/backend/tts.go index a9d7153f..85aa3457 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -6,8 +6,8 @@ import ( "os" "path/filepath" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -29,22 +29,22 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend } - grpcOpts := gRPCModelOpts(c) + grpcOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(config.Config{}, o, []model.Option{ + opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), - model.WithContext(o.Context), - model.WithAssetDir(o.AssetsDestination), + model.WithContext(appConfig.Context), + model.WithAssetDir(appConfig.AssetsDestination), model.WithLoadGRPCLoadModelOpts(grpcOpts), }) - piperModel, err := o.Loader.BackendLoader(opts...) + piperModel, err := loader.BackendLoader(opts...) 
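Taken together, transcript.go and tts.go now share one calling convention: the *model.ModelLoader and both config types are explicit parameters rather than fields on the old options.Option. A hedged usage sketch against the signatures shown above (file and voice names are placeholders):

package sketch

import (
	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/config"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func run(ml *model.ModelLoader, backendCfg config.BackendConfig, appCfg *config.ApplicationConfig) error {
	// Transcription: loader first, then backend config, then app config.
	res, err := backend.ModelTranscription("sample.wav", "en", ml, backendCfg, appCfg)
	if err != nil {
		return err
	}
	_ = res

	// TTS: note the argument order differs (app config before backend config).
	wavPath, _, err := backend.ModelTTS("", "hello from LocalAI", "voice-model", ml, appCfg, backendCfg)
	if err != nil {
		return err
	}
	_ = wavPath
	return nil
}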
if err != nil { return "", nil, err } @@ -53,19 +53,19 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt return "", nil, fmt.Errorf("could not load piper model") } - if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } - fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") - filePath := filepath.Join(o.AudioDir, fileName) + fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav") + filePath := filepath.Join(appConfig.AudioDir, fileName) // If the model file is not empty, we pass it joined with the model path modelPath := "" if modelFile != "" { if bb != model.TransformersMusicGen { - modelPath = filepath.Join(o.Loader.ModelPath, modelFile) - if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + modelPath = filepath.Join(loader.ModelPath, modelFile) + if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil { return "", nil, err } } else { diff --git a/core/options/options.go b/core/config/application_config.go similarity index 69% rename from core/options/options.go rename to core/config/application_config.go index 72aea1a3..d90ae906 100644 --- a/core/options/options.go +++ b/core/config/application_config.go @@ -1,4 +1,4 @@ -package options +package config import ( "context" @@ -6,16 +6,14 @@ import ( "encoding/json" "time" - "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) -type Option struct { +type ApplicationConfig struct { Context context.Context ConfigFile string - Loader *model.ModelLoader + ModelPath string UploadLimitMB, Threads, ContextSize int F16 bool Debug, DisableMessage bool @@ -27,7 +25,6 @@ type Option struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string - Metrics *metrics.Metrics ModelLibraryURL string @@ -52,10 +49,10 @@ type Option struct { WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration } -type AppOption func(*Option) +type AppOption func(*ApplicationConfig) -func NewOptions(o ...AppOption) *Option { - opt := &Option{ +func NewApplicationConfig(o ...AppOption) *ApplicationConfig { + opt := &ApplicationConfig{ Context: context.Background(), UploadLimitMB: 15, Threads: 1, @@ -70,63 +67,69 @@ func NewOptions(o ...AppOption) *Option { } func WithModelsURL(urls ...string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ModelsURL = urls } } +func WithModelPath(path string) AppOption { + return func(o *ApplicationConfig) { + o.ModelPath = path + } +} + func WithCors(b bool) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.CORS = b } } func WithModelLibraryURL(url string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ModelLibraryURL = url } } -var EnableWatchDog = func(o *Option) { +var EnableWatchDog = func(o *ApplicationConfig) { o.WatchDog = true } -var EnableWatchDogIdleCheck = func(o *Option) { +var EnableWatchDogIdleCheck = func(o *ApplicationConfig) { o.WatchDog = true o.WatchDogIdle = true } -var EnableWatchDogBusyCheck = func(o *Option) { +var EnableWatchDogBusyCheck = func(o *ApplicationConfig) { o.WatchDog = true o.WatchDogBusy = true } func SetWatchDogBusyTimeout(t time.Duration) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.WatchDogBusyTimeout = t } } func 
SetWatchDogIdleTimeout(t time.Duration) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.WatchDogIdleTimeout = t } } -var EnableSingleBackend = func(o *Option) { +var EnableSingleBackend = func(o *ApplicationConfig) { o.SingleBackend = true } -var EnableParallelBackendRequests = func(o *Option) { +var EnableParallelBackendRequests = func(o *ApplicationConfig) { o.ParallelBackendRequests = true } -var EnableGalleriesAutoload = func(o *Option) { +var EnableGalleriesAutoload = func(o *ApplicationConfig) { o.AutoloadGalleries = true } func WithExternalBackend(name string, uri string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { if o.ExternalGRPCBackends == nil { o.ExternalGRPCBackends = make(map[string]string) } @@ -135,25 +138,25 @@ func WithExternalBackend(name string, uri string) AppOption { } func WithCorsAllowOrigins(b string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.BackendAssets = f } } func WithStringGalleries(galls string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { if galls == "" { log.Debug().Msgf("no galleries to load") o.Galleries = []gallery.Gallery{} @@ -168,102 +171,96 @@ func WithStringGalleries(galls string) AppOption { } func WithGalleries(galleries []gallery.Gallery) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ConfigFile = configFile } } -func WithModelLoader(loader *model.ModelLoader) AppOption { - return func(o *Option) { - o.Loader = loader - } -} - func WithUploadLimitMB(limit int) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.F16 = f16 } } func WithDebug(debug bool) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.AudioDir = audioDir } } func WithImageDir(imageDir string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ImageDir = imageDir 
} } func WithUploadDir(uploadDir string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.UploadDir = uploadDir } } func WithApiKeys(apiKeys []string) AppOption { - return func(o *Option) { + return func(o *ApplicationConfig) { o.ApiKeys = apiKeys } } -func WithMetrics(meter *metrics.Metrics) AppOption { - return func(o *Option) { - o.Metrics = meter - } -} +// func WithMetrics(meter *metrics.Metrics) AppOption { +// return func(o *StartupOptions) { +// o.Metrics = meter +// } +// } diff --git a/core/config/config.go b/core/config/backend_config.go similarity index 77% rename from core/config/config.go rename to core/config/backend_config.go index af203ecc..3098da86 100644 --- a/core/config/config.go +++ b/core/config/backend_config.go @@ -9,15 +9,16 @@ import ( "strings" "sync" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) -type Config struct { - PredictionOptions `yaml:"parameters"` - Name string `yaml:"name"` +type BackendConfig struct { + schema.PredictionOptions `yaml:"parameters"` + Name string `yaml:"name"` F16 bool `yaml:"f16"` Threads int `yaml:"threads"` @@ -159,37 +160,55 @@ type TemplateConfig struct { Functions string `yaml:"function"` } -type ConfigLoader struct { - configs map[string]Config - sync.Mutex -} - -func (c *Config) SetFunctionCallString(s string) { +func (c *BackendConfig) SetFunctionCallString(s string) { c.functionCallString = s } -func (c *Config) SetFunctionCallNameString(s string) { +func (c *BackendConfig) SetFunctionCallNameString(s string) { c.functionCallNameString = s } -func (c *Config) ShouldUseFunctions() bool { +func (c *BackendConfig) ShouldUseFunctions() bool { return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) } -func (c *Config) ShouldCallSpecificFunction() bool { +func (c *BackendConfig) ShouldCallSpecificFunction() bool { return len(c.functionCallNameString) > 0 } -func (c *Config) FunctionToCall() string { +func (c *BackendConfig) FunctionToCall() string { return c.functionCallNameString } +func defaultPredictOptions(modelFile string) schema.PredictionOptions { + return schema.PredictionOptions{ + TopP: 0.7, + TopK: 80, + Maxtokens: 512, + Temperature: 0.9, + Model: modelFile, + } +} + +func DefaultConfig(modelFile string) *BackendConfig { + return &BackendConfig{ + PredictionOptions: defaultPredictOptions(modelFile), + } +} + +////// Config Loader //////// + +type BackendConfigLoader struct { + configs map[string]BackendConfig + sync.Mutex +} + // Load a config file for a model -func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) { +func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigLoader, debug bool, threads, ctx int, f16 bool) (*BackendConfig, error) { // Load a config file if present after the model name modelConfig := filepath.Join(modelPath, modelName+".yaml") - var cfg *Config + var cfg *BackendConfig defaults := func() { cfg = DefaultConfig(modelName) @@ -199,13 +218,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct cfg.Debug = debug } - cfgExisting, exists := cm.GetConfig(modelName) + cfgExisting, exists := cl.GetBackendConfig(modelName) if !exists { if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { + if err := 
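The renamed application_config.go keeps the functional-options pattern: each With* helper returns an AppOption closure, and NewApplicationConfig folds them over a defaulted struct. A minimal usage sketch (the model path is a placeholder):

package sketch

import (
	"context"

	"github.com/go-skynet/LocalAI/core/config"
)

func newAppConfig(ctx context.Context) *config.ApplicationConfig {
	// Options apply left to right over the defaults set in NewApplicationConfig.
	return config.NewApplicationConfig(
		config.WithContext(ctx),
		config.WithModelPath("/srv/models"),
		config.WithDebug(true),
	)
}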
cl.LoadBackendConfig(modelConfig); err != nil { return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) } - cfgExisting, exists = cm.GetConfig(modelName) + cfgExisting, exists = cl.GetBackendConfig(modelName) if exists { cfg = &cfgExisting } else { @@ -238,29 +257,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct return cfg, nil } -func defaultPredictOptions(modelFile string) PredictionOptions { - return PredictionOptions{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), } } - -func DefaultConfig(modelFile string) *Config { - return &Config{ - PredictionOptions: defaultPredictOptions(modelFile), - } -} - -func NewConfigLoader() *ConfigLoader { - return &ConfigLoader{ - configs: make(map[string]Config), - } -} -func ReadConfigFile(file string) ([]*Config, error) { - c := &[]*Config{} +func ReadBackendConfigFile(file string) ([]*BackendConfig, error) { + c := &[]*BackendConfig{} f, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("cannot read config file: %w", err) @@ -272,8 +275,8 @@ func ReadConfigFile(file string) ([]*Config, error) { return *c, nil } -func ReadConfig(file string) (*Config, error) { - c := &Config{} +func ReadBackendConfig(file string) (*BackendConfig, error) { + c := &BackendConfig{} f, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("cannot read config file: %w", err) @@ -285,10 +288,10 @@ func ReadConfig(file string) (*Config, error) { return c, nil } -func (cm *ConfigLoader) LoadConfigFile(file string) error { +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error { cm.Lock() defer cm.Unlock() - c, err := ReadConfigFile(file) + c, err := ReadBackendConfigFile(file) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } @@ -299,49 +302,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error { return nil } -func (cm *ConfigLoader) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfig(file) +func (cl *BackendConfigLoader) LoadBackendConfig(file string) error { + cl.Lock() + defer cl.Unlock() + c, err := ReadBackendConfig(file) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } - cm.configs[c.Name] = *c + cl.configs[c.Name] = *c return nil } -func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] +func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + cl.Lock() + defer cl.Unlock() + v, exists := cl.configs[m] return v, exists } -func (cm *ConfigLoader) GetAllConfigs() []Config { - cm.Lock() - defer cm.Unlock() - var res []Config - for _, v := range cm.configs { +func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + cl.Lock() + defer cl.Unlock() + var res []BackendConfig + for _, v := range cl.configs { res = append(res, v) } return res } -func (cm *ConfigLoader) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() +func (cl *BackendConfigLoader) ListBackendConfigs() []string { + cl.Lock() + defer cl.Unlock() var res []string - for k := range cm.configs { + for k := range cl.configs { res = append(res, k) } return res } // Preload prepare models if they are not local but url or huggingface repositories -func (cm *ConfigLoader) Preload(modelPath string) error { - cm.Lock() - defer cm.Unlock() +func (cl 
*BackendConfigLoader) Preload(modelPath string) error { + cl.Lock() + defer cl.Unlock() status := func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) @@ -349,7 +352,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error { log.Info().Msgf("Preloading models from %s", modelPath) - for i, config := range cm.configs { + for i, config := range cl.configs { // Download files and verify their SHA for _, file := range config.DownloadFiles { @@ -381,25 +384,25 @@ func (cm *ConfigLoader) Preload(modelPath string) error { } } - cc := cm.configs[i] + cc := cl.configs[i] c := &cc c.PredictionOptions.Model = md5Name - cm.configs[i] = *c + cl.configs[i] = *c } - if cm.configs[i].Name != "" { - log.Info().Msgf("Model name: %s", cm.configs[i].Name) + if cl.configs[i].Name != "" { + log.Info().Msgf("Model name: %s", cl.configs[i].Name) } - if cm.configs[i].Description != "" { - log.Info().Msgf("Model description: %s", cm.configs[i].Description) + if cl.configs[i].Description != "" { + log.Info().Msgf("Model description: %s", cl.configs[i].Description) } - if cm.configs[i].Usage != "" { - log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage) + if cl.configs[i].Usage != "" { + log.Info().Msgf("Model usage: \n%s", cl.configs[i].Usage) } } return nil } -func (cm *ConfigLoader) LoadConfigs(path string) error { +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error { cm.Lock() defer cm.Unlock() entries, err := os.ReadDir(path) @@ -419,7 +422,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error { if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { continue } - c, err := ReadConfig(filepath.Join(path, file.Name())) + c, err := ReadBackendConfig(filepath.Join(path, file.Name())) if err == nil { cm.configs[c.Name] = *c } diff --git a/core/config/config_test.go b/core/config/config_test.go index d1e92d5c..b18e083f 100644 --- a/core/config/config_test.go +++ b/core/config/config_test.go @@ -4,8 +4,7 @@ import ( "os" . "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" - "github.com/go-skynet/LocalAI/pkg/model" + . "github.com/onsi/ginkgo/v2" . 
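backend_config.go leaves the loader itself unchanged in shape, a mutex-guarded map of BackendConfig keyed by model name; only the names moved to the BackendConfig* prefix. A short sketch of the renamed API (the path is a placeholder):

package sketch

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/config"
)

func listConfigs() error {
	cl := config.NewBackendConfigLoader()
	// Reads every .yaml/.yml in the directory into the loader's map.
	if err := cl.LoadBackendConfigsFromPath("/srv/models"); err != nil {
		return err
	}
	for _, name := range cl.ListBackendConfigs() {
		if cfg, ok := cl.GetBackendConfig(name); ok {
			fmt.Println(name, "->", cfg.Model)
		}
	}
	return nil
}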
"github.com/onsi/gomega" ) @@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { configFile = os.Getenv("CONFIG_FILE") It("Test ReadConfigFile", func() { - config, err := ReadConfigFile(configFile) + config, err := ReadBackendConfigFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -28,29 +27,26 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewConfigLoader() - opts := options.NewOptions() - modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) - options.WithModelLoader(modelLoader)(opts) - - err := cm.LoadConfigs(opts.Loader.ModelPath) + cm := NewBackendConfigLoader() + opts := NewApplicationConfig() + err := cm.LoadBackendConfigsFromPath(opts.ModelPath) Expect(err).To(BeNil()) - Expect(cm.ListConfigs()).ToNot(BeNil()) + Expect(cm.ListBackendConfigs()).ToNot(BeNil()) // config should includes gpt4all models's api.config - Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) + Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all")) // config should includes gpt2 models's api.config - Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) + Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2")) // config should includes text-embedding-ada-002 models's api.config - Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) + Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config - Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) + Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config - Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) + Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1")) }) }) }) diff --git a/core/http/api.go b/core/http/api.go index 7d228152..e2646a14 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -3,122 +3,29 @@ package http import ( "encoding/json" "errors" - "fmt" "os" "strings" - "github.com/go-skynet/LocalAI/api/localai" - "github.com/go-skynet/LocalAI/api/openai" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/metrics" - "github.com/go-skynet/LocalAI/pkg/assets" "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/startup" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/recover" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) -func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { - options := options.NewOptions(opts...) 
- - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) - log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - - startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...) - - cl := config.NewConfigLoader() - if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error loading config files: %s", err.Error()) - } - - if options.ConfigFile != "" { - if err := cl.LoadConfigFile(options.ConfigFile); err != nil { - log.Error().Msgf("error loading config file: %s", err.Error()) - } - } - - if err := cl.Preload(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error downloading models: %s", err.Error()) - } - - if options.PreloadJSONModels != "" { - if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - if options.PreloadModelsFromPath != "" { - if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - if options.Debug { - for _, v := range cl.ListConfigs() { - cfg, _ := cl.GetConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) - } - } - - if options.AssetsDestination != "" { - // Extract files from the embedded FS - err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) - if err != nil { - log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) - } - } - - // turn off any process that was started by GRPC if the context is canceled - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - options.Loader.StopAllGRPC() - }() - - if options.WatchDog { - wd := model.NewWatchDog( - options.Loader, - options.WatchDogBusyTimeout, - options.WatchDogIdleTimeout, - options.WatchDogBusy, - options.WatchDogIdle) - options.Loader.SetWatchDog(wd) - go wd.Run() - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - wd.Shutdown() - }() - } - - return options, cl, nil -} - -func App(opts ...options.AppOption) (*fiber.App, error) { - - options, cl, err := Startup(opts...) 
- if err != nil { - return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) - } - +func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.DisableMessage, + BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: appConfig.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -139,7 +46,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) { }, }) - if options.Debug { + if appConfig.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) @@ -147,17 +54,25 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Default middleware config - if !options.Debug { + if !appConfig.Debug { app.Use(recover.New()) } - if options.Metrics != nil { - app.Use(metrics.APIMiddleware(options.Metrics)) + metricsService, err := services.NewLocalAIMetricsService() + if err != nil { + return nil, err + } + + if metricsService != nil { + app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) + app.Hooks().OnShutdown(func() error { + return metricsService.Shutdown() + }) } // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) == 0 { + if len(appConfig.ApiKeys) == 0 { return c.Next() } @@ -172,10 +87,10 @@ func App(opts ...options.AppOption) (*fiber.App, error) { } // Add file keys to options.ApiKeys - options.ApiKeys = append(options.ApiKeys, fileKeys...) + appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) 
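The metrics block above ties the service lifetime to the fiber app via a shutdown hook. Reduced to its essentials, with a stand-in type for the real services.LocalAIMetricsService:

package sketch

import "github.com/gofiber/fiber/v2"

// metricsSvc stands in for services.LocalAIMetricsService; only the
// Shutdown contract matters to the hook.
type metricsSvc struct{}

func (s *metricsSvc) Shutdown() error { return nil }

func wireMetrics(app *fiber.App, s *metricsSvc) {
	// Runs when app.Shutdown() is called, so the service is stopped
	// alongside the HTTP server.
	app.Hooks().OnShutdown(func() error {
		return s.Shutdown()
	})
}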
 	}

-	if len(options.ApiKeys) == 0 {
+	if len(appConfig.ApiKeys) == 0 {
 		return c.Next()
 	}

@@ -189,7 +104,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
 		}
 		apiKey := authHeaderParts[1]
-		for _, key := range options.ApiKeys {
+		for _, key := range appConfig.ApiKeys {
 			if apiKey == key {
 				return c.Next()
 			}
@@ -199,20 +114,20 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
 	}

-	if options.CORS {
+	if appConfig.CORS {
 		var c func(ctx *fiber.Ctx) error
-		if options.CORSAllowOrigins == "" {
+		if appConfig.CORSAllowOrigins == "" {
 			c = cors.New()
 		} else {
-			c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
+			c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
 		}
 		app.Use(c)
 	}

 	// LocalAI API endpoints
-	galleryService := localai.NewGalleryService(options.Loader.ModelPath)
-	galleryService.Start(options.Context, cl)
+	galleryService := services.NewGalleryService(appConfig.ModelPath)
+	galleryService.Start(appConfig.Context, cl)

 	app.Get("/version", auth, func(c *fiber.Ctx) error {
 		return c.JSON(struct {
@@ -220,69 +135,63 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
 		}{Version: internal.PrintableVersion()})
 	})

-	// Make sure directories exists
-	os.MkdirAll(options.ImageDir, 0755)
-	os.MkdirAll(options.AudioDir, 0755)
-	os.MkdirAll(options.UploadDir, 0755)
-	os.MkdirAll(options.Loader.ModelPath, 0755)
-
 	// Load upload json
-	openai.LoadUploadConfig(options.UploadDir)
+	openai.LoadUploadConfig(appConfig.UploadDir)

-	modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
-	app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
-	app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
-	app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
-	app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
-	app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
-	app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
-	app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
+	modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
+	app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
+	app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
+	app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
+	app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
+	app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
+	app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
+	app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())

 	// openAI compatible API endpoint

 	// chat
-	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
-	app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
+	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
+	app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))

 	// edit
-	app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
-	app.Post("/edits", auth, openai.EditEndpoint(cl, options))
+	app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
+	app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))

 	// files
-	app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
-	app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
-	app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
-	app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
-	app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
-	app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
-	app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
-	app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
-	app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
-	app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
+	app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
+	app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
+	app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
+	app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
+	app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
+	app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
+	app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
+	app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
+	app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
+	app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))

 	// completion
-	app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
-	app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
-	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
+	app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
+	app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
+	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))

 	// embeddings
-	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
-	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
-	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
+	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
+	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
+	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))

 	// audio
-	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
-	app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
+	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
+	app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))

 	// images
-	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
+	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))

-	if options.ImageDir != "" {
-		app.Static("/generated-images", options.ImageDir)
+	if appConfig.ImageDir != "" {
+		app.Static("/generated-images", appConfig.ImageDir)
 	}

-	if options.AudioDir != "" {
-		app.Static("/generated-audio", options.AudioDir)
+	if appConfig.AudioDir != "" {
+		app.Static("/generated-audio", appConfig.AudioDir)
 	}

 	ok := func(c *fiber.Ctx) error {
@@ -294,15 +203,15 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
 	app.Get("/readyz", ok)

 	// Experimental Backend Statistics Module
-	backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
+	backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
 	app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
 	app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))

 	// models
-	app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
-	app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
+	app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
+	app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))

-	app.Get("/metrics", metrics.MetricsHandler())
+	app.Get("/metrics", localai.LocalAIMetricsEndpoint())

 	return app, nil
 }
diff --git a/core/http/api_test.go b/core/http/api_test.go
index 9068b393..8f3cfc91 100644
--- a/core/http/api_test.go
+++ b/core/http/api_test.go
@@ -13,9 +13,10 @@ import (
 	"path/filepath"
 	"runtime"

+	"github.com/go-skynet/LocalAI/core/config"
 	. "github.com/go-skynet/LocalAI/core/http"
-	"github.com/go-skynet/LocalAI/core/options"
-	"github.com/go-skynet/LocalAI/metrics"
+	"github.com/go-skynet/LocalAI/core/startup"
+	"github.com/go-skynet/LocalAI/pkg/downloader"
 	"github.com/go-skynet/LocalAI/pkg/gallery"
 	"github.com/go-skynet/LocalAI/pkg/model"
@@ -127,25 +128,33 @@ var backendAssets embed.FS
 var _ = Describe("API test", func() {

 	var app *fiber.App
-	var modelLoader *model.ModelLoader
 	var client *openai.Client
 	var client2 *openaigo.Client
 	var c context.Context
 	var cancel context.CancelFunc
 	var tmpdir string
+	var modelDir string
+	var bcl *config.BackendConfigLoader
+	var ml *model.ModelLoader
+	var applicationConfig *config.ApplicationConfig

-	commonOpts := []options.AppOption{
-		options.WithDebug(true),
-		options.WithDisableMessage(true),
+	commonOpts := []config.AppOption{
+		config.WithDebug(true),
+		config.WithDisableMessage(true),
 	}

 	Context("API with ephemeral models", func() {
+
-		BeforeEach(func() {
+		BeforeEach(func(sc SpecContext) {
 			var err error
 			tmpdir, err = os.MkdirTemp("", "")
 			Expect(err).ToNot(HaveOccurred())
-			modelLoader = model.NewModelLoader(tmpdir)
+			modelDir = filepath.Join(tmpdir, "models")
+			backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
+			err = os.Mkdir(backendAssetsDir, 0755)
+			Expect(err).ToNot(HaveOccurred())
+
 			c, cancel = context.WithCancel(context.Background())

 			g := []gallery.GalleryModel{
@@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
 				},
 			}

-			metricsService, err := metrics.SetupMetrics()
+			bcl, ml, applicationConfig, err = startup.Startup(
+				append(commonOpts,
+					config.WithContext(c),
+					config.WithGalleries(galleries),
+					config.WithModelPath(modelDir),
+					config.WithBackendAssets(backendAssets),
+					config.WithBackendAssetsOutput(backendAssetsDir))...)
 			Expect(err).ToNot(HaveOccurred())

-			app, err = App(
-				append(commonOpts,
-					options.WithMetrics(metricsService),
-					options.WithContext(c),
-					options.WithGalleries(galleries),
-					options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
+			app, err = App(bcl, ml, applicationConfig)
 			Expect(err).ToNot(HaveOccurred())
+
 			go app.Listen("127.0.0.1:9090")

 			defaultConfig := openai.DefaultConfig("")
@@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
 			}, "2m").ShouldNot(HaveOccurred())
 		})

-		AfterEach(func() {
+		AfterEach(func(sc SpecContext) {
 			cancel()
-			app.Shutdown()
-			os.RemoveAll(tmpdir)
+			if app != nil {
+				err := app.Shutdown()
+				Expect(err).ToNot(HaveOccurred())
+			}
+			err := os.RemoveAll(tmpdir)
+			Expect(err).ToNot(HaveOccurred())
+			_, err = os.ReadDir(tmpdir)
+			Expect(err).To(HaveOccurred())
 		})

 		Context("Applying models", func() {
-			It("applies models from a gallery", func() {
+			It("applies models from a gallery", func() {
 				models := getModels("http://127.0.0.1:9090/models/available")
 				Expect(len(models)).To(Equal(2), fmt.Sprint(models))
 				Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
@@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
 				}, "360s", "10s").Should(Equal(true))
 				Expect(resp["message"]).ToNot(ContainSubstring("error"))

-				dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
+				dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
 				Expect(err).ToNot(HaveOccurred())

-				_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
+				_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
 				Expect(err).ToNot(HaveOccurred())

 				content := map[string]interface{}{}
@@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
 				}
 			})
 			It("overrides models", func() {
+
 				response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
 					URL:  "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
 					Name: "bert",
@@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
 					return response["processed"].(bool)
 				}, "360s", "10s").Should(Equal(true))

-				dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
+				dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
 				Expect(err).ToNot(HaveOccurred())

 				content := map[string]interface{}{}
@@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
 					return response["processed"].(bool)
 				}, "360s", "10s").Should(Equal(true))

-				dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
+				dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
 				Expect(err).ToNot(HaveOccurred())

 				content := map[string]interface{}{}
@@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
 			var err error
 			tmpdir, err = os.MkdirTemp("", "")
 			Expect(err).ToNot(HaveOccurred())
+			modelDir = filepath.Join(tmpdir, "models")
+			backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
+			err = os.Mkdir(backendAssetsDir, 0755)
+			Expect(err).ToNot(HaveOccurred())

-			modelLoader = model.NewModelLoader(tmpdir)
 			c, cancel = context.WithCancel(context.Background())

 			galleries := []gallery.Gallery{
@@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
 				},
 			}

-			metricsService, err := metrics.SetupMetrics()
-			Expect(err).ToNot(HaveOccurred())
-
-			app, err = App(
+			bcl, ml, applicationConfig, err = startup.Startup(
 				append(commonOpts,
-					options.WithContext(c),
-					options.WithMetrics(metricsService),
-					options.WithAudioDir(tmpdir),
-					options.WithImageDir(tmpdir),
-					options.WithGalleries(galleries),
-					options.WithModelLoader(modelLoader),
-					options.WithBackendAssets(backendAssets),
-					options.WithBackendAssetsOutput(tmpdir))...,
+					config.WithContext(c),
+					config.WithAudioDir(tmpdir),
+					config.WithImageDir(tmpdir),
+					config.WithGalleries(galleries),
+					config.WithModelPath(modelDir),
+					config.WithBackendAssets(backendAssets),
+					config.WithBackendAssetsOutput(tmpdir))...,
 			)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = App(bcl, ml, applicationConfig)
+			Expect(err).ToNot(HaveOccurred())
+
 			go app.Listen("127.0.0.1:9090")

 			defaultConfig := openai.DefaultConfig("")
@@ -527,8 +547,14 @@ var _ = Describe("API test", func() {

 		AfterEach(func() {
 			cancel()
-			app.Shutdown()
-			os.RemoveAll(tmpdir)
+			if app != nil {
+				err := app.Shutdown()
+				Expect(err).ToNot(HaveOccurred())
+			}
+			err := os.RemoveAll(tmpdir)
+			Expect(err).ToNot(HaveOccurred())
+			_, err = os.ReadDir(tmpdir)
+			Expect(err).To(HaveOccurred())
 		})
 		It("installs and is capable to run tts", Label("tts"), func() {
 			if runtime.GOOS != "linux" {
@@ -599,20 +625,20 @@ var _ = Describe("API test", func() {

 	Context("API query", func() {
 		BeforeEach(func() {
-			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
+			modelPath := os.Getenv("MODELS_PATH")
 			c, cancel = context.WithCancel(context.Background())

-			metricsService, err := metrics.SetupMetrics()
-			Expect(err).ToNot(HaveOccurred())
+			var err error

-			app, err = App(
+			bcl, ml, applicationConfig, err = startup.Startup(
 				append(commonOpts,
-					options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
-					options.WithContext(c),
-					options.WithModelLoader(modelLoader),
-					options.WithMetrics(metricsService),
+					config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
+					config.WithContext(c),
+					config.WithModelPath(modelPath),
 				)...)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = App(bcl, ml, applicationConfig)
+			Expect(err).ToNot(HaveOccurred())

 			go app.Listen("127.0.0.1:9090")

 			defaultConfig := openai.DefaultConfig("")
@@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
 		})
 		AfterEach(func() {
 			cancel()
-			app.Shutdown()
+			if app != nil {
+				err := app.Shutdown()
+				Expect(err).ToNot(HaveOccurred())
+			}
 		})
 		It("returns the models list", func() {
 			models, err := client.ListModels(context.TODO())
@@ -811,20 +840,20 @@ var _ = Describe("API test", func() {

 	Context("Config file", func() {
 		BeforeEach(func() {
-			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
+			modelPath := os.Getenv("MODELS_PATH")
 			c, cancel = context.WithCancel(context.Background())

-			metricsService, err := metrics.SetupMetrics()
-			Expect(err).ToNot(HaveOccurred())
-
-			app, err = App(
+			var err error
+			bcl, ml, applicationConfig, err = startup.Startup(
 				append(commonOpts,
-					options.WithContext(c),
-					options.WithMetrics(metricsService),
-					options.WithModelLoader(modelLoader),
-					options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
+					config.WithContext(c),
+					config.WithModelPath(modelPath),
+					config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
 			)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = App(bcl, ml, applicationConfig)
+			Expect(err).ToNot(HaveOccurred())
+
 			go app.Listen("127.0.0.1:9090")

 			defaultConfig := openai.DefaultConfig("")
@@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
 		})
 		AfterEach(func() {
 			cancel()
-			app.Shutdown()
+			if app != nil {
+				err := app.Shutdown()
+				Expect(err).ToNot(HaveOccurred())
+			}
 		})
 		It("can generate chat completions from config file (list1)", func() {
 			resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
diff --git a/api/ctx/fiber.go b/core/http/ctx/fiber.go
similarity index 100%
rename from api/ctx/fiber.go
rename to core/http/ctx/fiber.go
diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go
new file mode 100644
index 00000000..8c7a664a
--- /dev/null
+++ b/core/http/endpoints/localai/backend_monitor.go
@@ -0,0 +1,36 @@
+package localai
+
+import (
+	"github.com/go-skynet/LocalAI/core/schema"
+	"github.com/go-skynet/LocalAI/core/services"
+	"github.com/gofiber/fiber/v2"
+)
+
+func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+
+		input := new(schema.BackendMonitorRequest)
+		// Get input data from the request body
+		if err := c.BodyParser(input); err != nil {
+			return err
+		}
+
+		resp, err := bm.CheckAndSample(input.Model)
+		if err != nil {
+			return err
+		}
+		return c.JSON(resp)
+	}
+}
+
+func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		input := new(schema.BackendMonitorRequest)
+		// Get input data from the request body
+		if err := c.BodyParser(input); err != nil {
+			return err
+		}
+
+		return bm.ShutdownModel(input.Model)
+	}
+}
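For reference, a hypothetical client call against these handlers; the body matches schema.BackendMonitorRequest, and the reply is whatever CheckAndSample returns, serialized as JSON. The host and model name are illustrative:

    package main

    import (
        "fmt"
        "io"
        "net/http"
        "strings"
    )

    func main() {
        // Ask the monitor for a status sample of a currently loaded backend.
        resp, err := http.Post("http://127.0.0.1:8080/backend/monitor",
            "application/json",
            strings.NewReader(`{"model": "bert-embeddings"}`)) // illustrative model
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }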
galleries: %+v", m) + } + dat, err := json.Marshal(models) + if err != nil { + return err + } + return c.Send(dat) + } +} + +// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! +func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + return c.Send(dat) + } +} + +func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s already exists", input.Name) + } + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + log.Debug().Msgf("Adding %+v to gallery list", *input) + mgs.galleries = append(mgs.galleries, *input) + return c.Send(dat) + } +} + +func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s is not currently registered", input.Name) + } + mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) + return c.Send(nil) + } +} diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go new file mode 100644 index 00000000..23c2af7a --- /dev/null +++ b/core/http/endpoints/localai/metrics.go @@ -0,0 +1,43 @@ +package localai + +import ( + "time" + + "github.com/go-skynet/LocalAI/core/services" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +func LocalAIMetricsEndpoint() fiber.Handler { + + return adaptor.HTTPHandler(promhttp.Handler()) +} + +type apiMiddlewareConfig struct { + Filter func(c *fiber.Ctx) bool + metricsService *services.LocalAIMetricsService +} + +func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler { + cfg := apiMiddlewareConfig{ + metricsService: metrics, + Filter: func(c *fiber.Ctx) bool { + return c.Path() == "/metrics" + }, + } + + return func(c *fiber.Ctx) error { + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + path := c.Path() + method := c.Method() + + start := time.Now() + err := c.Next() + elapsed := float64(time.Since(start)) / float64(time.Second) + cfg.metricsService.ObserveAPICall(method, path, elapsed) + return err + } +} diff --git a/api/localai/localai.go b/core/http/endpoints/localai/tts.go similarity index 56% rename from api/localai/localai.go rename to core/http/endpoints/localai/tts.go index 9d5bbf6c..84fb7a55 100644 --- a/api/localai/localai.go +++ b/core/http/endpoints/localai/tts.go @@ -1,37 +1,32 @@ package localai import ( - fiberContext "github.com/go-skynet/LocalAI/api/ctx" "github.com/go-skynet/LocalAI/core/backend" - config "github.com/go-skynet/LocalAI/core/config" - 
"github.com/rs/zerolog/log" + "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" ) -type TTSRequest struct { - Model string `json:"model" yaml:"model"` - Input string `json:"input" yaml:"input"` - Backend string `json:"backend" yaml:"backend"` -} - -func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(TTSRequest) + input := new(schema.TTSRequest) // Get input data from the request body if err := c.BodyParser(input); err != nil { return err } - modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false) + modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false) if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false) + cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false) if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) @@ -44,7 +39,7 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) cfg.Backend = input.Backend } - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg) + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg) if err != nil { return err } diff --git a/api/openai/chat.go b/core/http/endpoints/openai/chat.go similarity index 90% rename from api/openai/chat.go rename to core/http/endpoints/openai/chat.go index cd535f0a..3add0972 100644 --- a/api/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -9,8 +9,7 @@ import ( "time" "github.com/go-skynet/LocalAI/core/backend" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" @@ -21,12 +20,12 @@ import ( "github.com/valyala/fasthttp" ) -func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { emptyMessage := "" id := uuid.New().String() created := int(time.Now().Unix()) - process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { initialMessage := schema.OpenAIResponse{ ID: id, Created: created, @@ -36,7 +35,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } responses <- initialMessage - ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage 
diff --git a/api/openai/chat.go b/core/http/endpoints/openai/chat.go
similarity index 90%
rename from api/openai/chat.go
rename to core/http/endpoints/openai/chat.go
index cd535f0a..3add0972 100644
--- a/api/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -9,8 +9,7 @@ import (
 	"time"

 	"github.com/go-skynet/LocalAI/core/backend"
-	config "github.com/go-skynet/LocalAI/core/config"
-	"github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
 	"github.com/go-skynet/LocalAI/core/schema"
 	"github.com/go-skynet/LocalAI/pkg/grammar"
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -21,12 +20,12 @@ import (
 	"github.com/valyala/fasthttp"
 )

-func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	emptyMessage := ""
 	id := uuid.New().String()
 	created := int(time.Now().Unix())

-	process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+	process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
 		initialMessage := schema.OpenAIResponse{
 			ID:      id,
 			Created: created,
@@ -36,7 +35,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		responses <- initialMessage

-		ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			resp := schema.OpenAIResponse{
 				ID:      id,
 				Created: created,
@@ -55,9 +54,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		})
 		close(responses)
 	}
-	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
@@ -78,7 +77,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		responses <- initialMessage

-		result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
+		result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
 		if err != nil {
 			log.Error().Msgf("error handling question: %s", err.Error())
 			return
@@ -154,12 +153,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 	return func(c *fiber.Ctx) error {
 		processFunctions := false
 		funcs := grammar.Functions{}
-		modelFile, input, err := readRequest(c, o, true)
+		modelFile, input, err := readRequest(c, ml, startupOptions, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

-		config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -252,7 +251,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 					FunctionName: i.Name,
 					MessageIndex: messageIndex,
 				}
-				templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
+				templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
 				if err != nil {
 					log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
 				} else {
@@ -320,7 +319,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		templateFile := ""

 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
+		if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
 			templateFile = config.Model
 		}

@@ -333,7 +332,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		if templateFile != "" {
-			templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
+			templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
 				SystemPrompt:         config.SystemPrompt,
 				SuppressSystemPrompt: suppressConfigSystemPrompt,
 				Input:                predInput,
@@ -357,9 +356,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			responses := make(chan schema.OpenAIResponse)

 			if !processFunctions {
-				go process(predInput, input, config, o.Loader, responses)
+				go process(predInput, input, config, ml, responses)
 			} else {
-				go processTools(noActionName, predInput, input, config, o.Loader, responses)
+				go processTools(noActionName, predInput, input, config, ml, responses)
 			}

 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -413,7 +412,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		// no streaming mode
 		default:
-			result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
+			result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
 				if !processFunctions {
 					// no function is called, just reply and use stop as finish reason
 					*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
@@ -425,7 +424,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 				switch {
 				case noActionsToRun:
-					result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
+					result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
 					if err != nil {
 						log.Error().Msgf("error handling question: %s", err.Error())
 						return
@@ -506,7 +505,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 	}
 }

-func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
 	log.Debug().Msgf("nothing to do, computing a reply")

 	// If there is a message that the LLM already sends as part of the JSON reply, use it
@@ -535,7 +534,7 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
 		images = append(images, m.StringImages...)
 	}

-	predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
 	if err != nil {
 		log.Error().Msgf("inference error: %s", err.Error())
 		return "", err
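On the client side nothing changes: the route stays OpenAI-compatible, and the tests exercise it with the go-openai client. A condensed sketch of the non-streaming path; the base URL and model name are illustrative:

    package main

    import (
        "context"
        "fmt"

        openai "github.com/sashabaranov/go-openai"
    )

    func main() {
        cfg := openai.DefaultConfig("")          // no API key required by default
        cfg.BaseURL = "http://127.0.0.1:8080/v1" // point the client at LocalAI
        client := openai.NewClientWithConfig(cfg)

        resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
            Model:    "list1", // a model name from your config file
            Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How are you?"}},
        })
        if err != nil {
            panic(err)
        }
        fmt.Println(resp.Choices[0].Message.Content)
    }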
diff --git a/api/openai/completion.go b/core/http/endpoints/openai/completion.go
similarity index 82%
rename from api/openai/completion.go
rename to core/http/endpoints/openai/completion.go
index af56625e..9344f9fe 100644
--- a/api/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -9,8 +9,8 @@ import (
 	"time"

 	"github.com/go-skynet/LocalAI/core/backend"
-	config "github.com/go-skynet/LocalAI/core/config"
-	"github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
+
 	"github.com/go-skynet/LocalAI/core/schema"
 	"github.com/go-skynet/LocalAI/pkg/grammar"
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -21,12 +21,12 @@ import (
 )

 // https://platform.openai.com/docs/api-reference/completions
-func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	id := uuid.New().String()
 	created := int(time.Now().Unix())

-	process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
-		ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+	process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+		ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			resp := schema.OpenAIResponse{
 				ID:      id,
 				Created: created,
@@ -53,14 +53,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, o, true)
+		modelFile, input, err := readRequest(c, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

 		log.Debug().Msgf("`input`: %+v", input)

-		config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -84,7 +84,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 		templateFile := ""

 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
+		if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
 			templateFile = config.Model
 		}

@@ -100,7 +100,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			predInput := config.PromptStrings[0]

 			if templateFile != "" {
-				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
+				templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
 					Input: predInput,
 				})
 				if err == nil {
@@ -111,7 +111,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			responses := make(chan schema.OpenAIResponse)

-			go process(predInput, input, config, o.Loader, responses)
+			go process(predInput, input, config, ml, responses)

 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -153,7 +153,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 		for k, i := range config.PromptStrings {
 			if templateFile != "" {
 				// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
+				templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
 					SystemPrompt: config.SystemPrompt,
 					Input:        i,
 				})
@@ -164,7 +164,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			}

 			r, tokenUsage, err := ComputeChoices(
-				input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
+				input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
 					*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
 				}, nil)
 			if err != nil {
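The legacy engines-style route registered in api.go resolves the model from the path, presumably via fiberContext.ModelFromContext, so the body only needs a prompt. A hypothetical raw call; host and model name are illustrative:

    package main

    import (
        "fmt"
        "io"
        "net/http"
        "strings"
    )

    func main() {
        // The :model path parameter supplies the model name.
        resp, err := http.Post(
            "http://127.0.0.1:8080/v1/engines/gpt4all-j/completions",
            "application/json",
            strings.NewReader(`{"prompt": "Say this is a test"}`))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }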
diff --git a/api/openai/edit.go b/core/http/endpoints/openai/edit.go
similarity index 77%
rename from api/openai/edit.go
rename to core/http/endpoints/openai/edit.go
index 56b17920..25497095 100644
--- a/api/openai/edit.go
+++ b/core/http/endpoints/openai/edit.go
@@ -6,8 +6,8 @@ import (
 	"time"

 	"github.com/go-skynet/LocalAI/core/backend"
-	config "github.com/go-skynet/LocalAI/core/config"
-	"github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
+
 	"github.com/go-skynet/LocalAI/core/schema"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
@@ -16,14 +16,14 @@ import (
 	"github.com/rs/zerolog/log"
 )

-func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, o, true)
+		modelFile, input, err := readRequest(c, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

-		config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -33,7 +33,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		templateFile := ""

 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
+		if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
 			templateFile = config.Model
 		}

@@ -46,7 +46,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		for _, i := range config.InputStrings {
 			if templateFile != "" {
-				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
+				templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
 					Input:        i,
 					Instruction:  input.Instruction,
 					SystemPrompt: config.SystemPrompt,
@@ -57,7 +57,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			}

-			r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
+			r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
 				*c = append(*c, schema.Choice{Text: s})
 			}, nil)
 			if err != nil {
diff --git a/api/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go
similarity index 73%
rename from api/openai/embeddings.go
rename to core/http/endpoints/openai/embeddings.go
index 198493e1..774b0a5e 100644
--- a/api/openai/embeddings.go
+++ b/core/http/endpoints/openai/embeddings.go
@@ -6,24 +6,25 @@ import (
 	"time"

 	"github.com/go-skynet/LocalAI/core/backend"
-	config "github.com/go-skynet/LocalAI/core/config"
+	"github.com/go-skynet/LocalAI/core/config"
+	"github.com/go-skynet/LocalAI/pkg/model"
+
 	"github.com/go-skynet/LocalAI/core/schema"
 	"github.com/google/uuid"

-	"github.com/go-skynet/LocalAI/core/options"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )

 // https://platform.openai.com/docs/api-reference/embeddings
-func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readRequest(c, o, true)
+		model, input, err := readRequest(c, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

-		config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -33,7 +34,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 		for i, s := range config.InputToken {
 			// get the model function to call for the result
-			embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
+			embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
 			if err != nil {
 				return err
 			}
@@ -47,7 +48,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 		for i, s := range config.InputStrings {
 			// get the model function to call for the result
-			embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
+			embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
 			if err != nil {
 				return err
 			}
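As the two loops above show, the endpoint embeds either plain strings or pre-tokenized input. A hypothetical string-input call; the host and model name are illustrative:

    package main

    import (
        "fmt"
        "io"
        "net/http"
        "strings"
    )

    func main() {
        // Single-string input; token arrays (config.InputToken above) are also accepted.
        resp, err := http.Post("http://127.0.0.1:8080/v1/embeddings", "application/json",
            strings.NewReader(`{"model": "bert-embeddings", "input": "A long time ago..."}`))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }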
diff --git a/api/openai/files.go b/core/http/endpoints/openai/files.go
similarity index 83%
rename from api/openai/files.go
rename to core/http/endpoints/openai/files.go
index 140b4151..5cb8d7a9 100644
--- a/api/openai/files.go
+++ b/core/http/endpoints/openai/files.go
@@ -8,8 +8,8 @@ import (
 	"path/filepath"
 	"time"

-	config "github.com/go-skynet/LocalAI/core/config"
-	"github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
+
 	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
@@ -62,7 +62,7 @@ func LoadUploadConfig(uploadPath string) {
 }

 // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
-func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		file, err := c.FormFile("file")
 		if err != nil {
@@ -70,8 +70,8 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 		}

 		// Check the file size
-		if file.Size > int64(o.UploadLimitMB*1024*1024) {
-			return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB))
+		if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
+			return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
 		}

 		purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
@@ -82,7 +82,7 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 		// Sanitize the filename to prevent directory traversal
 		filename := utils.SanitizeFileName(file.Filename)

-		savePath := filepath.Join(o.UploadDir, filename)
+		savePath := filepath.Join(appConfig.UploadDir, filename)

 		// Check if file already exists
 		if _, err := os.Stat(savePath); !os.IsNotExist(err) {
@@ -104,13 +104,13 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 		}

 		uploadedFiles = append(uploadedFiles, f)
-		saveUploadConfig(o.UploadDir)
+		saveUploadConfig(appConfig.UploadDir)
 		return c.Status(fiber.StatusOK).JSON(f)
 	}
 }

 // ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
-func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	type ListFiles struct {
 		Data   []File
 		Object string
@@ -150,7 +150,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) {
 }

 // GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
-func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		file, err := getFileFromRequest(c)
 		if err != nil {
@@ -162,7 +162,7 @@ func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.
 }

 // DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
-func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	type DeleteStatus struct {
 		Id     string
 		Object string
@@ -175,7 +175,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 			return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
 		}

-		err = os.Remove(filepath.Join(o.UploadDir, file.Filename))
+		err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
 		if err != nil {
 			// If the file doesn't exist then we should just continue to remove it
 			if !errors.Is(err, os.ErrNotExist) {
@@ -191,7 +191,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 			}
 		}

-		saveUploadConfig(o.UploadDir)
+		saveUploadConfig(appConfig.UploadDir)
 		return c.JSON(DeleteStatus{
 			Id:     file.ID,
 			Object: "file",
@@ -201,14 +201,14 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
 }

 // GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
-func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		file, err := getFileFromRequest(c)
 		if err != nil {
 			return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
 		}

-		fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename))
+		fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
 		if err != nil {
 			return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
 		}
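The upload handler reads a multipart "file" part plus a "purpose" field via c.FormFile / c.FormValue, so a client call could be sketched as below; file name, contents, and host are illustrative:

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "mime/multipart"
        "net/http"
    )

    func main() {
        // Build the multipart body with the fields the handler expects.
        var buf bytes.Buffer
        w := multipart.NewWriter(&buf)
        part, _ := w.CreateFormFile("file", "dataset.jsonl")
        io.WriteString(part, `{"prompt": "hi", "completion": "hello"}`)
        w.WriteField("purpose", "fine-tune")
        w.Close()

        resp, err := http.Post("http://127.0.0.1:8080/v1/files", w.FormDataContentType(), &buf)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }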
diff --git a/api/openai/files_test.go b/core/http/endpoints/openai/files_test.go
similarity index 92%
rename from api/openai/files_test.go
rename to core/http/endpoints/openai/files_test.go
index 535cde8b..a036bd0d 100644
--- a/api/openai/files_test.go
+++ b/core/http/endpoints/openai/files_test.go
@@ -11,8 +11,8 @@ import (
 	"path/filepath"
 	"strings"

-	config "github.com/go-skynet/LocalAI/core/config"
-	"github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
+
 	utils2 "github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/gofiber/fiber/v2"
 	"github.com/stretchr/testify/assert"
@@ -25,11 +25,11 @@ type ListFiles struct {
 	Object string
 }

-func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) {
+func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
 	// Preparing the mocked objects
-	loader = &config.ConfigLoader{}
+	loader = &config.BackendConfigLoader{}

-	option = &options.Option{
+	option = &config.ApplicationConfig{
 		UploadLimitMB: 10,
 		UploadDir:     "test_dir",
 	}
@@ -52,9 +52,9 @@ func startUpApp() (app *fiber.App, option *options.Option, loader *config.Config
 func TestUploadFileExceedSizeLimit(t *testing.T) {
 	// Preparing the mocked objects
-	loader := &config.ConfigLoader{}
+	loader := &config.BackendConfigLoader{}

-	option := &options.Option{
+	option := &config.ApplicationConfig{
 		UploadLimitMB: 10,
 		UploadDir:     "test_dir",
 	}
@@ -174,9 +174,9 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt
 	return app.Test(request)
 }

-func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) {
+func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
 	// Create a file that exceeds the limit
-	file := createTestFile(t, fileName, fileSize, o)
+	file := createTestFile(t, fileName, fileSize, appConfig)

 	// Creating a new HTTP Request
 	body, writer := newMultipartFile(file.Name(), tag, purpose)
@@ -186,9 +186,9 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos
 	return app.Test(req)
 }

-func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File {
+func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
 	// Create a file that exceeds the limit
-	file := createTestFile(t, fileName, fileSize, o)
+	file := createTestFile(t, fileName, fileSize, appConfig)

 	// Creating a new HTTP Request
 	body, writer := newMultipartFile(file.Name(), tag, purpose)
@@ -233,7 +233,7 @@ func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipar
 }

 // Helper to create test files
-func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File {
+func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
 	err := os.MkdirAll(option.UploadDir, 0755)
 	if err != nil {
"x") if len(sizeParts) != 2 { - return fmt.Errorf("Invalid value for 'size'") + return fmt.Errorf("invalid value for 'size'") } width, err := strconv.Atoi(sizeParts[0]) if err != nil { - return fmt.Errorf("Invalid value for 'size'") + return fmt.Errorf("invalid value for 'size'") } height, err := strconv.Atoi(sizeParts[1]) if err != nil { - return fmt.Errorf("Invalid value for 'size'") + return fmt.Errorf("invalid value for 'size'") } b64JSON := false @@ -179,7 +179,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx tempDir := "" if !b64JSON { - tempDir = o.ImageDir + tempDir = appConfig.ImageDir } // Create a temporary file outputFile, err := os.CreateTemp(tempDir, "b64") @@ -196,7 +196,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx baseURL := c.BaseURL() - fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o) + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) if err != nil { return err } diff --git a/api/openai/inference.go b/core/http/endpoints/openai/inference.go similarity index 90% rename from api/openai/inference.go rename to core/http/endpoints/openai/inference.go index 184688b2..5d97d21d 100644 --- a/api/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -2,8 +2,8 @@ package openai import ( "github.com/go-skynet/LocalAI/core/backend" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" model "github.com/go-skynet/LocalAI/pkg/model" ) @@ -11,8 +11,8 @@ import ( func ComputeChoices( req *schema.OpenAIRequest, predInput string, - config *config.Config, - o *options.Option, + config *config.BackendConfig, + o *config.ApplicationConfig, loader *model.ModelLoader, cb func(string, *[]schema.Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { diff --git a/api/openai/list.go b/core/http/endpoints/openai/list.go similarity index 87% rename from api/openai/list.go rename to core/http/endpoints/openai/list.go index 614d5c80..04e611a2 100644 --- a/api/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -3,15 +3,15 @@ package openai import ( "regexp" - config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := loader.ListModels() + models, err := ml.ListModels() if err != nil { return err } @@ -40,7 +40,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func excludeConfigured := c.QueryBool("excludeConfigured", true) // Start with the known configurations - for _, c := range cm.GetAllConfigs() { + for _, c := range cl.GetAllBackendConfigs() { if excludeConfigured { mm[c.Model] = nil } diff --git a/api/openai/request.go b/core/http/endpoints/openai/request.go similarity index 89% rename from api/openai/request.go rename to core/http/endpoints/openai/request.go index 
diff --git a/api/openai/request.go b/core/http/endpoints/openai/request.go
similarity index 89%
rename from api/openai/request.go
rename to core/http/endpoints/openai/request.go
index 83c41d97..46ff2438 100644
--- a/api/openai/request.go
+++ b/core/http/endpoints/openai/request.go
@@ -5,13 +5,12 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"net/http"
 	"strings"

-	fiberContext "github.com/go-skynet/LocalAI/api/ctx"
-	config "github.com/go-skynet/LocalAI/core/config"
-	options "github.com/go-skynet/LocalAI/core/options"
+	"github.com/go-skynet/LocalAI/core/config"
+	fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
 	"github.com/go-skynet/LocalAI/core/schema"
 	"github.com/go-skynet/LocalAI/pkg/grammar"
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -19,11 +18,9 @@ import (
 	"github.com/rs/zerolog/log"
 )

-func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) {
+func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
 	input := new(schema.OpenAIRequest)
-	ctx, cancel := context.WithCancel(o.Context)
-	input.Context = ctx
-	input.Cancel = cancel
+
 	// Get input data from the request body
 	if err := c.BodyParser(input); err != nil {
 		return "", nil, fmt.Errorf("failed parsing request body: %w", err)
@@ -31,9 +28,13 @@ func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *sch

 	received, _ := json.Marshal(input)

+	ctx, cancel := context.WithCancel(o.Context)
+	input.Context = ctx
+	input.Cancel = cancel
+
 	log.Debug().Msgf("Request received: %s", string(received))

-	modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel)
+	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)

 	return modelFile, input, err
 }
@@ -50,7 +51,7 @@ func getBase64Image(s string) (string, error) {
 		defer resp.Body.Close()

 		// read the image data into memory
-		data, err := ioutil.ReadAll(resp.Body)
+		data, err := io.ReadAll(resp.Body)
 		if err != nil {
 			return "", err
 		}
@@ -69,7 +70,7 @@ func getBase64Image(s string) (string, error) {
 	return "", fmt.Errorf("not valid string")
 }

-func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
+func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
 	if input.Echo {
 		config.Echo = input.Echo
 	}
@@ -270,8 +271,8 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
 	}
 }

-func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
-	cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
+func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
+	cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)

 	// Set the parameters for the language model prediction
 	updateRequestConfig(cfg, input)
"github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) // https://platform.openai.com/docs/api-reference/audio/create -func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { +func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, o, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -59,7 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) + tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) if err != nil { return err } diff --git a/core/schema/localai.go b/core/schema/localai.go new file mode 100644 index 00000000..115183a3 --- /dev/null +++ b/core/schema/localai.go @@ -0,0 +1,21 @@ +package schema + +import ( + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitorRequest struct { + Model string `json:"model" yaml:"model"` +} + +type BackendMonitorResponse struct { + MemoryInfo *gopsutil.MemoryInfoStat + MemoryPercent float32 + CPUPercent float64 +} + +type TTSRequest struct { + Model string `json:"model" yaml:"model"` + Input string `json:"input" yaml:"input"` + Backend string `json:"backend" yaml:"backend"` +} diff --git a/core/schema/openai.go b/core/schema/openai.go index 53dd5324..1c13847c 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -3,8 +3,6 @@ package schema import ( "context" - config "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/pkg/grammar" ) @@ -108,10 +106,10 @@ type ChatCompletionResponseFormat struct { } type OpenAIRequest struct { - config.PredictionOptions + PredictionOptions - Context context.Context - Cancel context.CancelFunc + Context context.Context `json:"-"` + Cancel context.CancelFunc `json:"-"` // whisper File string `json:"file" validate:"required"` diff --git a/core/config/prediction.go b/core/schema/prediction.go similarity index 99% rename from core/config/prediction.go rename to core/schema/prediction.go index dccb4dfb..efd085a4 100644 --- a/core/config/prediction.go +++ b/core/schema/prediction.go @@ -1,4 +1,4 @@ -package config +package schema type PredictionOptions struct { diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go new file mode 100644 index 00000000..88176753 --- /dev/null +++ b/core/services/backend_monitor.go @@ -0,0 +1,140 @@ +package services + +import ( + "context" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/rs/zerolog/log" + + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitor 
struct {
+	configLoader *config.BackendConfigLoader
+	modelLoader  *model.ModelLoader
+	options      *config.ApplicationConfig // Kept in case we later need to inspect ExternalGRPCBackends; unused for now, hence the generic name.
+}
+
+func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
+	return BackendMonitor{
+		configLoader: configLoader,
+		modelLoader:  modelLoader,
+		options:      appConfig,
+	}
+}
+
+func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
+	config, exists := bm.configLoader.GetBackendConfig(modelName)
+	var backendId string
+	if exists {
+		backendId = config.Model
+	} else {
+		// Last-ditch effort: use the name raw and see if a backend happens to match.
+		backendId = modelName
+	}
+
+	if !strings.HasSuffix(backendId, ".bin") {
+		backendId = fmt.Sprintf("%s.bin", backendId)
+	}
+
+	return backendId, nil
+}
+
+func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
+	// Resolve the model name to a backend ID with the same logic CheckAndSample uses.
+	backend, err := bm.getModelLoaderIDFromModelName(model)
+	if err != nil {
+		return nil, err
+	}
+
+	pid, err := bm.modelLoader.GetGRPCPID(backend)
+	if err != nil {
+		log.Error().Msgf("model %s : failed to find pid %+v", model, err)
+		return nil, err
+	}
+
+	// Despite the name, NewProcess does _not_ create a new process: it looks up an existing process by PID.
+	backendProcess, err := gopsutil.NewProcess(int32(pid))
+	if err != nil {
+		log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
+		return nil, err
+	}
+
+	memInfo, err := backendProcess.MemoryInfo()
+	if err != nil {
+		log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
+		return nil, err
+	}
+
+	memPercent, err := backendProcess.MemoryPercent()
+	if err != nil {
+		log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
+		return nil, err
+	}
+
+	cpuPercent, err := backendProcess.CPUPercent()
+	if err != nil {
+		log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
+		return nil, err
+	}
+
+	return &schema.BackendMonitorResponse{
+		MemoryInfo:    memInfo,
+		MemoryPercent: memPercent,
+		CPUPercent:    cpuPercent,
+	}, nil
+}
+
+func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
+	backendId, err := bm.getModelLoaderIDFromModelName(modelName)
+	if err != nil {
+		return nil, err
+	}
+	modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
+	if modelAddr == "" {
+		return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
+	}
+
+	status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
+	if rpcErr != nil {
+		log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
+		val, slbErr := bm.SampleLocalBackendProcess(backendId)
+		if slbErr != nil {
+			return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
+		}
+		return &proto.StatusResponse{
+			State: proto.StatusResponse_ERROR,
+			Memory: &proto.MemoryUsageData{
+				Total: val.MemoryInfo.VMS,
+				Breakdown: map[string]uint64{
+					"gopsutil-RSS":
val.MemoryInfo.RSS, + }, + }, + }, nil + } + return status, nil +} + +func (bm BackendMonitor) ShutdownModel(modelName string) error { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) + if err != nil { + return err + } + return bm.modelLoader.ShutdownModel(backendId) +} diff --git a/core/services/gallery.go b/core/services/gallery.go new file mode 100644 index 00000000..826f4573 --- /dev/null +++ b/core/services/gallery.go @@ -0,0 +1,167 @@ +package services + +import ( + "context" + "encoding/json" + "os" + "strings" + "sync" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + "gopkg.in/yaml.v2" +) + +type GalleryService struct { + modelPath string + sync.Mutex + C chan gallery.GalleryOp + statuses map[string]*gallery.GalleryOpStatus +} + +func NewGalleryService(modelPath string) *GalleryService { + return &GalleryService{ + modelPath: modelPath, + C: make(chan gallery.GalleryOp), + statuses: make(map[string]*gallery.GalleryOpStatus), + } +} + +func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetGalleryConfigFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) +} + +func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) { + g.Lock() + defer g.Unlock() + g.statuses[s] = op +} + +func (g *GalleryService) GetStatus(s string) *gallery.GalleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses[s] +} + +func (g *GalleryService) GetAllStatus() map[string]*gallery.GalleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses +} + +func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) { + go func() { + for { + select { + case <-c.Done(): + return + case op := <-g.C: + utils.ResetDownloadTimers() + + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0}) + + // updates the status with an error + updateError := func(e error) { + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) + } + + // displayDownload displays the download progress + progressCallback := func(fileName string, current string, total string, percentage float64) { + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + utils.DisplayDownloadFunction(fileName, current, total, percentage) + } + + var err error + // if the request contains a gallery name, we apply the gallery from the gallery list + if op.GalleryName != "" { + if strings.Contains(op.GalleryName, "@") { + err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } + } else { + err = prepareModel(g.modelPath, op.Req, cl, progressCallback) + } + + if err != nil { + updateError(err) + continue + } + + // Reload models + err = cl.LoadBackendConfigsFromPath(g.modelPath) + if err != nil { + updateError(err) + continue + } + + err = cl.Preload(g.modelPath) + if err != nil { + updateError(err) + continue + } + + g.UpdateStatus(op.Id, 
&gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100}) + } + } + }() +} + +type galleryModel struct { + gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 + ID string `json:"id"` +} + +func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { + var err error + for _, r := range requests { + utils.ResetDownloadTimers() + if r.ID == "" { + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { + if strings.Contains(r.ID, "@") { + err = gallery.InstallModelFromGallery( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } else { + err = gallery.InstallModelFromGalleryByName( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } + } + } + return err +} + +func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error { + dat, err := os.ReadFile(s) + if err != nil { + return err + } + var requests []galleryModel + + if err := yaml.Unmarshal(dat, &requests); err != nil { + return err + } + + return processRequests(modelPath, s, cl, galleries, requests) +} + +func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error { + var requests []galleryModel + err := json.Unmarshal([]byte(s), &requests) + if err != nil { + return err + } + + return processRequests(modelPath, s, cl, galleries, requests) +} diff --git a/core/services/metrics.go b/core/services/metrics.go new file mode 100644 index 00000000..b3107398 --- /dev/null +++ b/core/services/metrics.go @@ -0,0 +1,54 @@ +package services + +import ( + "context" + + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/metric" + metricApi "go.opentelemetry.io/otel/sdk/metric" +) + +type LocalAIMetricsService struct { + Meter metric.Meter + ApiTimeMetric metric.Float64Histogram +} + +func (m *LocalAIMetricsService) ObserveAPICall(method string, path string, duration float64) { + opts := metric.WithAttributes( + attribute.String("method", method), + attribute.String("path", path), + ) + m.ApiTimeMetric.Record(context.Background(), duration, opts) +} + +// setupOTelSDK bootstraps the OpenTelemetry pipeline. +// If it does not return an error, make sure to call shutdown for proper cleanup. +func NewLocalAIMetricsService() (*LocalAIMetricsService, error) { + exporter, err := prometheus.New() + if err != nil { + return nil, err + } + provider := metricApi.NewMeterProvider(metricApi.WithReader(exporter)) + meter := provider.Meter("github.com/go-skynet/LocalAI") + + apiTimeMetric, err := meter.Float64Histogram("api_call", metric.WithDescription("api calls")) + if err != nil { + return nil, err + } + + return &LocalAIMetricsService{ + Meter: meter, + ApiTimeMetric: apiTimeMetric, + }, nil +} + +func (lams LocalAIMetricsService) Shutdown() error { + // TODO: Not sure how to actually do this: + //// setupOTelSDK bootstraps the OpenTelemetry pipeline. + //// If it does not return an error, make sure to call shutdown for proper cleanup. 
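+	// A possible approach (untested sketch): keep a reference to the
+	// *metricApi.MeterProvider created in NewLocalAIMetricsService and call
+	// its Shutdown(ctx) method here to flush and unregister the reader.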
+
+	log.Warn().Msgf("LocalAIMetricsService Shutdown called, but proper OTel SDK shutdown is not yet implemented")
+	return nil
+}
diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go
new file mode 100644
index 00000000..0c7eff2d
--- /dev/null
+++ b/core/startup/config_file_watcher.go
@@ -0,0 +1,100 @@
+package startup
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path"
+
+	"github.com/fsnotify/fsnotify"
+	"github.com/go-skynet/LocalAI/core/config"
+	"github.com/imdario/mergo"
+	"github.com/rs/zerolog/log"
+)
+
+type WatchConfigDirectoryCloser func() error
+
+func ReadApiKeysJson(configDir string, appConfig *config.ApplicationConfig) error {
+	fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
+	if err != nil {
+		return err
+	}
+	// Parse JSON content from the file
+	var fileKeys []string
+	if err := json.Unmarshal(fileContent, &fileKeys); err != nil {
+		return err
+	}
+	appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
+	return nil
+}
+
+func ReadExternalBackendsJson(configDir string, appConfig *config.ApplicationConfig) error {
+	fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
+	if err != nil {
+		return err
+	}
+	// Parse JSON content from the file
+	var fileBackends map[string]string
+	if err := json.Unmarshal(fileContent, &fileBackends); err != nil {
+		return err
+	}
+	return mergo.Merge(&appConfig.ExternalGRPCBackends, fileBackends)
+}
+
+var CONFIG_FILE_UPDATES = map[string]func(configDir string, appConfig *config.ApplicationConfig) error{
+	"api_keys.json":          ReadApiKeysJson,
+	"external_backends.json": ReadExternalBackendsJson,
+}
+
+func WatchConfigDirectory(configDir string, appConfig *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) {
+	if len(configDir) == 0 {
+		return nil, fmt.Errorf("configDir is blank")
+	}
+	configWatcher, err := fsnotify.NewWatcher()
+	if err != nil {
+		// Return the error instead of log.Fatal so the caller decides how to handle it.
+		return nil, fmt.Errorf("unable to create a watcher for the LocalAI Configuration Directory: %w", err)
+	}
+	ret := func() error {
+		configWatcher.Close()
+		return nil
+	}
+
+	// Start listening for events.
+	go func() {
+		for {
+			select {
+			case event, ok := <-configWatcher.Events:
+				if !ok {
+					return
+				}
+				if event.Has(fsnotify.Write) {
+					for targetName, watchFn := range CONFIG_FILE_UPDATES {
+						// event.Name carries the full path, so compare only the file name.
+						if path.Base(event.Name) == targetName {
+							if err := watchFn(configDir, appConfig); err != nil {
+								log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
+							}
+						}
+					}
+				}
+			case werr, ok := <-configWatcher.Errors:
+				if !ok {
+					return
+				}
+				log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", werr)
+			}
+		}
+	}()
+
+	// Add a path.
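+	// Note: fsnotify watches are not recursive, so only files directly inside
+	// configDir (e.g. api_keys.json, external_backends.json) are observed.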
+	err = configWatcher.Add(configDir)
+	if err != nil {
+		return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %w", err)
+	}
+
+	return ret, nil
+}
diff --git a/core/startup/startup.go b/core/startup/startup.go
new file mode 100644
index 00000000..43e6646d
--- /dev/null
+++ b/core/startup/startup.go
@@ -0,0 +1,128 @@
+package startup
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/go-skynet/LocalAI/core/config"
+	"github.com/go-skynet/LocalAI/core/services"
+	"github.com/go-skynet/LocalAI/internal"
+	"github.com/go-skynet/LocalAI/pkg/assets"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	pkgStartup "github.com/go-skynet/LocalAI/pkg/startup"
+	"github.com/rs/zerolog"
+	"github.com/rs/zerolog/log"
+)
+
+func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
+	options := config.NewApplicationConfig(opts...)
+
+	zerolog.SetGlobalLevel(zerolog.InfoLevel)
+	if options.Debug {
+		zerolog.SetGlobalLevel(zerolog.DebugLevel)
+	}
+
+	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
+	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
+
+	// Make sure directories exist
+	if options.ModelPath == "" {
+		return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
+	}
+	err := os.MkdirAll(options.ModelPath, 0755)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %w", err)
+	}
+	if options.ImageDir != "" {
+		err := os.MkdirAll(options.ImageDir, 0755)
+		if err != nil {
+			return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %w", err)
+		}
+	}
+	if options.AudioDir != "" {
+		err := os.MkdirAll(options.AudioDir, 0755)
+		if err != nil {
+			return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %w", err)
+		}
+	}
+	if options.UploadDir != "" {
+		err := os.MkdirAll(options.UploadDir, 0755)
+		if err != nil {
+			return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %w", err)
+		}
+	}
+
+	// Resolve any model URLs or shorthands passed on the command line before the loaders scan ModelPath.
+	pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...)
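+	// Note: failures inside PreloadModelsConfigurations are presumably handled
+	// (logged) internally, since its result is not checked at this call site.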
+ + cl := config.NewBackendConfigLoader() + ml := model.NewModelLoader(options.ModelPath) + + if err := cl.LoadBackendConfigsFromPath(options.ModelPath); err != nil { + log.Error().Msgf("error loading config files: %s", err.Error()) + } + + if options.ConfigFile != "" { + if err := cl.LoadBackendConfigFile(options.ConfigFile); err != nil { + log.Error().Msgf("error loading config file: %s", err.Error()) + } + } + + if err := cl.Preload(options.ModelPath); err != nil { + log.Error().Msgf("error downloading models: %s", err.Error()) + } + + if options.PreloadJSONModels != "" { + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, nil, err + } + } + + if options.PreloadModelsFromPath != "" { + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, nil, err + } + } + + if options.Debug { + for _, v := range cl.ListBackendConfigs() { + cfg, _ := cl.GetBackendConfig(v) + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) + } + } + + if options.AssetsDestination != "" { + // Extract files from the embedded FS + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) + if err != nil { + log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) + } + } + + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + ml.StopAllGRPC() + }() + + if options.WatchDog { + wd := model.NewWatchDog( + ml, + options.WatchDogBusyTimeout, + options.WatchDogIdleTimeout, + options.WatchDogBusy, + options.WatchDogIdle) + ml.SetWatchDog(wd) + go wd.Run() + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + wd.Shutdown() + }() + } + + log.Info().Msg("core/startup process completed!") + return cl, ml, options, nil +} diff --git a/examples/bruno/LocalAI Test Requests/backend monitor/backend monitor.bru b/examples/bruno/LocalAI Test Requests/backend monitor/backend monitor.bru index e3f72134..51e3771a 100644 --- a/examples/bruno/LocalAI Test Requests/backend monitor/backend monitor.bru +++ b/examples/bruno/LocalAI Test Requests/backend monitor/backend monitor.bru @@ -6,6 +6,12 @@ meta { get { url: {{PROTOCOL}}{{HOST}}:{{PORT}}/backend/monitor - body: none + body: json auth: none } + +body:json { + { + "model": "{{DEFAULT_MODEL}}" + } +} diff --git a/examples/langchain/langchainjs-localai-example/src/index.mts b/examples/langchain/langchainjs-localai-example/src/index.mts index e6dcfb86..11faa384 100644 --- a/examples/langchain/langchainjs-localai-example/src/index.mts +++ b/examples/langchain/langchainjs-localai-example/src/index.mts @@ -4,7 +4,7 @@ import { Document } from "langchain/document"; import { initializeAgentExecutorWithOptions } from "langchain/agents"; import {Calculator} from "langchain/tools/calculator"; -const pathToLocalAi = process.env['OPENAI_API_BASE'] || 'http://api:8080/v1'; +const pathToLocalAI = process.env['OPENAI_API_BASE'] || 'http://api:8080/v1'; const fakeApiKey = process.env['OPENAI_API_KEY'] || '-'; const modelName = process.env['MODEL_NAME'] || 'gpt-3.5-turbo'; @@ -21,7 +21,7 @@ function getModel(): OpenAIChat { openAIApiKey: fakeApiKey, maxRetries: 2 }, { - basePath: 
pathToLocalAi, + basePath: pathToLocalAI, apiKey: fakeApiKey, }); } diff --git a/go.mod b/go.mod index bbd787b5..bbb90838 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df + github.com/fsnotify/fsnotify v1.7.0 github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 @@ -14,7 +15,6 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hpcloud/tail v1.0.0 github.com/imdario/mergo v0.3.16 - github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af @@ -64,8 +64,6 @@ require ( github.com/klauspost/pgzip v1.2.5 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nwaples/rardecode v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect @@ -104,7 +102,7 @@ require ( github.com/valyala/tcplisten v1.0.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.12.0 // indirect ) diff --git a/go.sum b/go.sum index 20dfbfb4..84aba3a0 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdf github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= @@ -72,7 +74,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= @@ -86,8 +87,6 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO 
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= @@ -117,11 +116,6 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= @@ -278,6 +272,8 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= diff --git a/main.go b/main.go index 7e4262ee..237191cf 100644 --- a/main.go +++ b/main.go @@ -13,11 +13,12 @@ import ( "time" "github.com/go-skynet/LocalAI/core/backend" - config "github.com/go-skynet/LocalAI/core/config" - api "github.com/go-skynet/LocalAI/core/http" - "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/config" + + "github.com/go-skynet/LocalAI/core/http" + "github.com/go-skynet/LocalAI/core/startup" + "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" @@ -206,6 +207,12 @@ func main() { EnvVars: 
[]string{"PRELOAD_BACKEND_ONLY"}, Value: false, }, + &cli.StringFlag{ + Name: "localai-config-dir", + Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.", + EnvVars: []string{"LOCALAI_CONFIG_DIR"}, + Value: "./configuration", + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -224,56 +231,56 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - opts := []options.AppOption{ - options.WithConfigFile(ctx.String("config-file")), - options.WithJSONStringPreload(ctx.String("preload-models")), - options.WithYAMLConfigPreload(ctx.String("preload-models-config")), - options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), - options.WithContextSize(ctx.Int("context-size")), - options.WithDebug(ctx.Bool("debug")), - options.WithImageDir(ctx.String("image-path")), - options.WithAudioDir(ctx.String("audio-path")), - options.WithUploadDir(ctx.String("upload-path")), - options.WithF16(ctx.Bool("f16")), - options.WithStringGalleries(ctx.String("galleries")), - options.WithModelLibraryURL(ctx.String("remote-library")), - options.WithDisableMessage(false), - options.WithCors(ctx.Bool("cors")), - options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - options.WithThreads(ctx.Int("threads")), - options.WithBackendAssets(backendAssets), - options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - options.WithUploadLimitMB(ctx.Int("upload-limit")), - options.WithApiKeys(ctx.StringSlice("api-keys")), - options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), + opts := []config.AppOption{ + config.WithConfigFile(ctx.String("config-file")), + config.WithJSONStringPreload(ctx.String("preload-models")), + config.WithYAMLConfigPreload(ctx.String("preload-models-config")), + config.WithModelPath(ctx.String("models-path")), + config.WithContextSize(ctx.Int("context-size")), + config.WithDebug(ctx.Bool("debug")), + config.WithImageDir(ctx.String("image-path")), + config.WithAudioDir(ctx.String("audio-path")), + config.WithUploadDir(ctx.String("upload-path")), + config.WithF16(ctx.Bool("f16")), + config.WithStringGalleries(ctx.String("galleries")), + config.WithModelLibraryURL(ctx.String("remote-library")), + config.WithDisableMessage(false), + config.WithCors(ctx.Bool("cors")), + config.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + config.WithThreads(ctx.Int("threads")), + config.WithBackendAssets(backendAssets), + config.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + config.WithUploadLimitMB(ctx.Int("upload-limit")), + config.WithApiKeys(ctx.StringSlice("api-keys")), + config.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), } idleWatchDog := ctx.Bool("enable-watchdog-idle") busyWatchDog := ctx.Bool("enable-watchdog-busy") if idleWatchDog || busyWatchDog { - opts = append(opts, options.EnableWatchDog) + opts = append(opts, config.EnableWatchDog) if idleWatchDog { - opts = append(opts, options.EnableWatchDogIdleCheck) + opts = append(opts, config.EnableWatchDogIdleCheck) dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) if err != nil { return err } - opts = append(opts, options.SetWatchDogIdleTimeout(dur)) + opts = append(opts, config.SetWatchDogIdleTimeout(dur)) } if busyWatchDog { - opts = append(opts, options.EnableWatchDogBusyCheck) 
+					opts = append(opts, config.EnableWatchDogBusyCheck)
 					dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
 					if err != nil {
 						return err
 					}
-					opts = append(opts, options.SetWatchDogBusyTimeout(dur))
+					opts = append(opts, config.SetWatchDogBusyTimeout(dur))
 				}
 			}
 			if ctx.Bool("parallel-requests") {
-				opts = append(opts, options.EnableParallelBackendRequests)
+				opts = append(opts, config.EnableParallelBackendRequests)
 			}
 			if ctx.Bool("single-active-backend") {
-				opts = append(opts, options.EnableSingleBackend)
+				opts = append(opts, config.EnableSingleBackend)
 			}
 
 			externalgRPC := ctx.StringSlice("external-grpc-backends")
@@ -281,30 +288,38 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
 			for _, v := range externalgRPC {
 				backend := v[:strings.IndexByte(v, ':')]
 				uri := v[strings.IndexByte(v, ':')+1:]
-				opts = append(opts, options.WithExternalBackend(backend, uri))
+				opts = append(opts, config.WithExternalBackend(backend, uri))
 			}
 
 			if ctx.Bool("autoload-galleries") {
-				opts = append(opts, options.EnableGalleriesAutoload)
+				opts = append(opts, config.EnableGalleriesAutoload)
 			}
 
 			if ctx.Bool("preload-backend-only") {
-				_, _, err := api.Startup(opts...)
+				_, _, _, err := startup.Startup(opts...)
 				return err
 			}
 
-			metrics, err := metrics.SetupMetrics()
+			cl, ml, options, err := startup.Startup(opts...)
+
 			if err != nil {
-				return err
+				return fmt.Errorf("failed basic startup tasks: %w", err)
 			}
 
-			opts = append(opts, options.WithMetrics(metrics))
-			app, err := api.App(opts...)
+			closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options)
+			if err != nil {
+				return fmt.Errorf("failed while watching configuration directory %s: %w", ctx.String("localai-config-dir"), err)
+			}
+			// Defer only after the error check; on failure the returned closer is nil.
+			defer closeConfigWatcherFn()
+
+			appHTTP, err := http.App(cl, ml, options)
+
+			if err != nil {
+				log.Error().Msg("Error during HTTP App constructor")
 				return err
 			}
 
-			return app.Listen(ctx.String("address"))
+			return appHTTP.Listen(ctx.String("address"))
 		},
 		Commands: []*cli.Command{
 			{
@@ -402,16 +417,17 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
 
 				text := strings.Join(ctx.Args().Slice(), " ")
 
-				opts := &options.Option{
-					Loader:            model.NewModelLoader(ctx.String("models-path")),
+				opts := &config.ApplicationConfig{
+					ModelPath:         ctx.String("models-path"),
 					Context:           context.Background(),
 					AudioDir:          outputDir,
 					AssetsDestination: ctx.String("backend-assets-path"),
 				}
+				ml := model.NewModelLoader(opts.ModelPath)
 
-				defer opts.Loader.StopAllGRPC()
+				defer ml.StopAllGRPC()
 
-				filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts, config.Config{})
+				filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, ml, opts, config.BackendConfig{})
 				if err != nil {
 					return err
 				}
@@ -464,27 +480,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
 				language := ctx.String("language")
 				threads := ctx.Int("threads")
 
-				opts := &options.Option{
-					Loader:            model.NewModelLoader(ctx.String("models-path")),
+				opts := &config.ApplicationConfig{
+					ModelPath:         ctx.String("models-path"),
 					Context:           context.Background(),
 					AssetsDestination: ctx.String("backend-assets-path"),
 				}
 
-				cl := config.NewConfigLoader()
-				if err := cl.LoadConfigs(ctx.String("models-path")); err != nil {
+				cl := config.NewBackendConfigLoader()
+				ml := model.NewModelLoader(opts.ModelPath)
+				if err := cl.LoadBackendConfigsFromPath(ctx.String("models-path")); err != nil {
 					return err
 				}
 
-				c, exists := 
cl.GetConfig(modelOption) + c, exists := cl.GetBackendConfig(modelOption) if !exists { return errors.New("model not found") } c.Threads = threads - defer opts.Loader.StopAllGRPC() + defer ml.StopAllGRPC() - tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts) + tr, err := backend.ModelTranscription(filename, language, ml, c, opts) if err != nil { return err } diff --git a/metrics/metrics.go b/metrics/metrics.go deleted file mode 100644 index 84b83161..00000000 --- a/metrics/metrics.go +++ /dev/null @@ -1,83 +0,0 @@ -package metrics - -import ( - "context" - "time" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/adaptor" - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/exporters/prometheus" - api "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/sdk/metric" -) - -type Metrics struct { - meter api.Meter - apiTimeMetric api.Float64Histogram -} - -// setupOTelSDK bootstraps the OpenTelemetry pipeline. -// If it does not return an error, make sure to call shutdown for proper cleanup. -func SetupMetrics() (*Metrics, error) { - exporter, err := prometheus.New() - if err != nil { - return nil, err - } - provider := metric.NewMeterProvider(metric.WithReader(exporter)) - meter := provider.Meter("github.com/go-skynet/LocalAI") - - apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls")) - if err != nil { - return nil, err - } - - return &Metrics{ - meter: meter, - apiTimeMetric: apiTimeMetric, - }, nil -} - -func MetricsHandler() fiber.Handler { - return adaptor.HTTPHandler(promhttp.Handler()) -} - -type apiMiddlewareConfig struct { - Filter func(c *fiber.Ctx) bool - metrics *Metrics -} - -func APIMiddleware(metrics *Metrics) fiber.Handler { - cfg := apiMiddlewareConfig{ - metrics: metrics, - Filter: func(c *fiber.Ctx) bool { - if c.Path() == "/metrics" { - return true - } - return false - }, - } - - return func(c *fiber.Ctx) error { - if cfg.Filter != nil && cfg.Filter(c) { - return c.Next() - } - path := c.Path() - method := c.Method() - - start := time.Now() - err := c.Next() - elapsed := float64(time.Since(start)) / float64(time.Second) - cfg.metrics.ObserveAPICall(method, path, elapsed) - return err - } -} - -func (m *Metrics) ObserveAPICall(method string, path string, duration float64) { - opts := api.WithAttributes( - attribute.String("method", method), - attribute.String("path", path), - ) - m.apiTimeMetric.Record(context.Background(), duration, opts) -} diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 80214f5b..b678ae0d 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -179,6 +179,10 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string, } defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) + } + // Create parent directory err = os.MkdirAll(filepath.Dir(filePath), 0755) if err != nil { diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index f454c611..21d3a03d 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -18,7 +18,6 @@ var _ = Describe("Model test", func() { defer os.RemoveAll(tempdir) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) 
Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go new file mode 100644 index 00000000..873c356d --- /dev/null +++ b/pkg/gallery/op.go @@ -0,0 +1,18 @@ +package gallery + +type GalleryOp struct { + Req GalleryModel + Id string + Galleries []Gallery + GalleryName string +} + +type GalleryOpStatus struct { + FileName string `json:"file_name"` + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` + Progress float64 `json:"progress"` + TotalFileSize string `json:"file_size"` + DownloadedFileSize string `json:"downloaded_size"` +} diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go index bf3f8a5b..5fd60114 100644 --- a/tests/integration/reflect_test.go +++ b/tests/integration/reflect_test.go @@ -3,7 +3,7 @@ package integration_test import ( "reflect" - config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/config" model "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega"