From b2785ff06e3eb7c1d62a6c3921ae706d58c054dd Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 12 Apr 2024 00:49:23 +0200 Subject: [PATCH] feat(gallery): support ConfigURLs (#2012) Signed-off-by: Ettore Di Giacinto --- core/http/api_test.go | 24 +++++++++++++++++++++ core/http/endpoints/localai/gallery.go | 4 +++- core/services/gallery.go | 5 +++++ docs/content/docs/features/model-gallery.md | 10 ++++++--- pkg/gallery/op.go | 1 + 5 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/http/api_test.go b/core/http/api_test.go index 804c15fe..1553ed21 100644 --- a/core/http/api_test.go +++ b/core/http/api_test.go @@ -43,6 +43,7 @@ Can you help rephrasing sentences? type modelApplyRequest struct { ID string `json:"id"` URL string `json:"url"` + ConfigURL string `json:"config_url"` Name string `json:"name"` Overrides map[string]interface{} `json:"overrides"` } @@ -366,6 +367,29 @@ var _ = Describe("API test", func() { Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("llama")) }) + It("apply models from config", func() { + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ConfigURL: "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/models/hermes-2-pro-mistral.yaml", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + Eventually(func() []string { + models, _ := client.ListModels(context.TODO()) + modelList := []string{} + for _, m := range models.Models { + modelList = append(modelList, m.ID) + } + return modelList + }, "360s", "10s").Should(ContainElements("hermes-2-pro-mistral")) + }) It("apply models without overrides", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: 
"https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index 5c295a2a..b693e7c3 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -19,7 +19,8 @@ type ModelGalleryEndpointService struct { } type GalleryModel struct { - ID string `json:"id"` + ID string `json:"id"` + ConfigURL string `json:"config_url"` gallery.GalleryModel } @@ -64,6 +65,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe Id: uuid.String(), GalleryName: input.ID, Galleries: mgs.galleries, + ConfigURL: input.ConfigURL, } return c.JSON(struct { ID string `json:"uuid"` diff --git a/core/services/gallery.go b/core/services/gallery.go index 826f4573..b068abbb 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -9,6 +9,7 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" "gopkg.in/yaml.v2" ) @@ -90,6 +91,9 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } else { err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) } + } else if op.ConfigURL != "" { + startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) + err = cl.Preload(g.modelPath) } else { err = prepareModel(g.modelPath, op.Req, cl, progressCallback) } @@ -129,6 +133,7 @@ func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galler utils.ResetDownloadTimers() if r.ID == "" { err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { if strings.Contains(r.ID, "@") { err = gallery.InstallModelFromGallery( diff --git a/docs/content/docs/features/model-gallery.md b/docs/content/docs/features/model-gallery.md index 
0d978122..05d15ef4 100644 --- a/docs/content/docs/features/model-gallery.md +++ b/docs/content/docs/features/model-gallery.md @@ -146,12 +146,16 @@ In the body of the request you must specify the model configuration file URL (`url`) ```bash LOCALAI=http://localhost:8080 curl $LOCALAI/models/apply -H "Content-Type: application/json" -d '{ - "url": "<MODEL_DEFINITION_URL>" + "config_url": "<MODEL_CONFIG_FILE_URL>" }' # or if from a repository curl $LOCALAI/models/apply -H "Content-Type: application/json" -d '{ "id": "<GALLERY>@<MODEL_NAME>" }' +# or from a gallery config +curl $LOCALAI/models/apply -H "Content-Type: application/json" -d '{ + "url": "<MODEL_DEFINITION_URL>" + }' ``` An example that installs openllama can be: @@ -159,8 +163,8 @@ An example that installs openllama can be: ```bash LOCALAI=http://localhost:8080 curl $LOCALAI/models/apply -H "Content-Type: application/json" -d '{ - "url": "https://github.com/go-skynet/model-gallery/blob/main/openllama_3b.yaml" - }' + "config_url": "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/models/hermes-2-pro-mistral.yaml" + }' ``` The API will return a job `uuid` that you can use to track the job progress: diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go index 873c356d..99796812 100644 --- a/pkg/gallery/op.go +++ b/pkg/gallery/op.go @@ -5,6 +5,7 @@ type GalleryOp struct { Id string Galleries []Gallery GalleryName string + ConfigURL string } type GalleryOpStatus struct {