mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-24 06:46:39 +00:00
refactor: consolidate usage of GetURI (#674)
Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
parent
d18f85df46
commit
78f3c3da48
@ -80,13 +80,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||
app.Use(recover.New())
|
||||
|
||||
if options.preloadJSONModels != "" {
|
||||
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil {
|
||||
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.preloadModelsFromPath != "" {
|
||||
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil {
|
||||
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
. "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@ -56,30 +57,10 @@ func getModelStatus(url string) (response map[string]interface{}) {
|
||||
}
|
||||
|
||||
func getModels(url string) (response []gallery.GalleryModel) {
|
||||
|
||||
//url := "http://localhost:AI/models/apply"
|
||||
|
||||
// Create the request payload
|
||||
|
||||
// Create the HTTP request
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading response body:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal the response into a map[string]interface{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
fmt.Println("Error unmarshaling JSON response:", err)
|
||||
return
|
||||
}
|
||||
utils.GetURI(url, func(url string, i []byte) error {
|
||||
// Unmarshal YAML data into a struct
|
||||
return json.Unmarshal(i, &response)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -48,9 +48,8 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
||||
|
||||
// prepareModel applies a
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||
var config gallery.Config
|
||||
|
||||
err := req.Get(&config)
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -144,40 +143,35 @@ func displayDownload(fileName string, current string, total string, percentage f
|
||||
}
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var requests []gallery.GalleryModel
|
||||
err = json.Unmarshal(dat, &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
||||
var requests []gallery.GalleryModel
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||
var requests []galleryModel
|
||||
err := json.Unmarshal([]byte(s), &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
|
||||
return err
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
||||
|
@ -23,9 +23,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
}
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
var config Config
|
||||
|
||||
err := model.Get(&config)
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -79,7 +77,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
|
||||
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
|
||||
var models []*GalleryModel = []*GalleryModel{}
|
||||
|
||||
err := utils.GetURI(gallery.URL, func(d []byte) error {
|
||||
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &models)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -65,6 +65,17 @@ type PromptTemplate struct {
|
||||
Content string `yaml:"content"`
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
var config Config
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ReadConfigFile(filePath string) (*Config, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
|
@ -1,14 +1,5 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
|
||||
// It is used to install the model by resolving the URL and downloading the files.
|
||||
// The other fields are used to override the configuration of the model.
|
||||
@ -34,52 +25,3 @@ type GalleryModel struct {
|
||||
const (
|
||||
githubURI = "github:"
|
||||
)
|
||||
|
||||
func (request GalleryModel) DecodeURL() (string, error) {
|
||||
input := request.URL
|
||||
var rawURL string
|
||||
|
||||
if strings.HasPrefix(input, githubURI) {
|
||||
parts := strings.Split(input, ":")
|
||||
repoParts := strings.Split(parts[1], "@")
|
||||
branch := "main"
|
||||
|
||||
if len(repoParts) > 1 {
|
||||
branch = repoParts[1]
|
||||
}
|
||||
|
||||
repoPath := strings.Split(repoParts[0], "/")
|
||||
org := repoPath[0]
|
||||
project := repoPath[1]
|
||||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
|
||||
// Handle regular URLs
|
||||
u, err := url.Parse(input)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
rawURL = u.String()
|
||||
// check if it's a file path
|
||||
} else if strings.HasPrefix(input, "file://") {
|
||||
return input, nil
|
||||
} else {
|
||||
|
||||
return "", fmt.Errorf("invalid URL format: %s", input)
|
||||
}
|
||||
|
||||
return rawURL, nil
|
||||
}
|
||||
|
||||
// Get fetches a model from a URL and unmarshals it into a struct
|
||||
func (request GalleryModel) Get(i interface{}) error {
|
||||
url, err := request.DecodeURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return utils.GetURI(url, func(d []byte) error {
|
||||
return yaml.Unmarshal(d, i)
|
||||
})
|
||||
}
|
||||
|
@ -6,37 +6,13 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type example struct {
|
||||
Name string `yaml:"name"`
|
||||
}
|
||||
|
||||
var _ = Describe("Gallery API tests", func() {
|
||||
|
||||
Context("requests", func() {
|
||||
It("parses github with a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
var e example
|
||||
err := req.Get(&e)
|
||||
e, err := GetGalleryConfigFromURL(req.URL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
It("parses URLS", func() {
|
||||
req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -1,12 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetURI(url string, f func(i []byte) error) error {
|
||||
const (
|
||||
githubURI = "github:"
|
||||
)
|
||||
|
||||
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
if strings.HasPrefix(url, githubURI) {
|
||||
parts := strings.Split(url, ":")
|
||||
repoParts := strings.Split(parts[1], "@")
|
||||
branch := "main"
|
||||
|
||||
if len(repoParts) > 1 {
|
||||
branch = repoParts[1]
|
||||
}
|
||||
|
||||
repoPath := strings.Split(repoParts[0], "/")
|
||||
org := repoPath[0]
|
||||
project := repoPath[1]
|
||||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
url = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(url, "file://") {
|
||||
rawURL := strings.TrimPrefix(url, "file://")
|
||||
// Read the response body
|
||||
@ -16,7 +38,7 @@ func GetURI(url string, f func(i []byte) error) error {
|
||||
}
|
||||
|
||||
// Unmarshal YAML data into a struct
|
||||
return f(body)
|
||||
return f(url, body)
|
||||
}
|
||||
|
||||
// Send a GET request to the URL
|
||||
@ -33,5 +55,5 @@ func GetURI(url string, f func(i []byte) error) error {
|
||||
}
|
||||
|
||||
// Unmarshal YAML data into a struct
|
||||
return f(body)
|
||||
return f(url, body)
|
||||
}
|
||||
|
36
pkg/utils/uri_test.go
Normal file
36
pkg/utils/uri_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/utils"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Gallery API tests", func() {
|
||||
Context("URI", func() {
|
||||
It("parses github with a branch", func() {
|
||||
Expect(
|
||||
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
Expect(
|
||||
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
It("parses github with urls", func() {
|
||||
Expect(
|
||||
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
13
pkg/utils/utils_suite_test.go
Normal file
13
pkg/utils/utils_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestUtils(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Utils test suite")
|
||||
}
|
Loading…
Reference in New Issue
Block a user