mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-24 06:46:39 +00:00
refactor: consolidate usage of GetURI (#674)
Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
parent
d18f85df46
commit
78f3c3da48
@ -80,13 +80,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
|||||||
app.Use(recover.New())
|
app.Use(recover.New())
|
||||||
|
|
||||||
if options.preloadJSONModels != "" {
|
if options.preloadJSONModels != "" {
|
||||||
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil {
|
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.preloadModelsFromPath != "" {
|
if options.preloadModelsFromPath != "" {
|
||||||
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil {
|
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
. "github.com/go-skynet/LocalAI/api"
|
. "github.com/go-skynet/LocalAI/api"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -56,30 +57,10 @@ func getModelStatus(url string) (response map[string]interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getModels(url string) (response []gallery.GalleryModel) {
|
func getModels(url string) (response []gallery.GalleryModel) {
|
||||||
|
utils.GetURI(url, func(url string, i []byte) error {
|
||||||
//url := "http://localhost:AI/models/apply"
|
// Unmarshal YAML data into a struct
|
||||||
|
return json.Unmarshal(i, &response)
|
||||||
// Create the request payload
|
})
|
||||||
|
|
||||||
// Create the HTTP request
|
|
||||||
resp, err := http.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error reading response body:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal the response into a map[string]interface{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error unmarshaling JSON response:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,9 +48,8 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
|||||||
|
|
||||||
// prepareModel applies a
|
// prepareModel applies a
|
||||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||||
var config gallery.Config
|
|
||||||
|
|
||||||
err := req.Get(&config)
|
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -144,40 +143,35 @@ func displayDownload(fileName string, current string, total string, percentage f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
type galleryModel struct {
|
||||||
|
gallery.GalleryModel
|
||||||
|
ID string `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||||
dat, err := os.ReadFile(s)
|
dat, err := os.ReadFile(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var requests []gallery.GalleryModel
|
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
||||||
err = json.Unmarshal(dat, &requests)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range requests {
|
|
||||||
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||||
var requests []gallery.GalleryModel
|
var requests []galleryModel
|
||||||
err := json.Unmarshal([]byte(s), &requests)
|
err := json.Unmarshal([]byte(s), &requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range requests {
|
for _, r := range requests {
|
||||||
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
|
if r.ID == "" {
|
||||||
return err
|
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
||||||
|
} else {
|
||||||
|
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
||||||
|
@ -23,9 +23,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
applyModel := func(model *GalleryModel) error {
|
applyModel := func(model *GalleryModel) error {
|
||||||
var config Config
|
config, err := GetGalleryConfigFromURL(model.URL)
|
||||||
|
|
||||||
err := model.Get(&config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -79,7 +77,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
|
|||||||
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
|
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
|
||||||
var models []*GalleryModel = []*GalleryModel{}
|
var models []*GalleryModel = []*GalleryModel{}
|
||||||
|
|
||||||
err := utils.GetURI(gallery.URL, func(d []byte) error {
|
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
|
||||||
return yaml.Unmarshal(d, &models)
|
return yaml.Unmarshal(d, &models)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -65,6 +65,17 @@ type PromptTemplate struct {
|
|||||||
Content string `yaml:"content"`
|
Content string `yaml:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||||
|
var config Config
|
||||||
|
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||||
|
return yaml.Unmarshal(d, &config)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ReadConfigFile(filePath string) (*Config, error) {
|
func ReadConfigFile(filePath string) (*Config, error) {
|
||||||
// Read the YAML file
|
// Read the YAML file
|
||||||
yamlFile, err := os.ReadFile(filePath)
|
yamlFile, err := os.ReadFile(filePath)
|
||||||
|
@ -1,14 +1,5 @@
|
|||||||
package gallery
|
package gallery
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
|
// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
|
||||||
// It is used to install the model by resolving the URL and downloading the files.
|
// It is used to install the model by resolving the URL and downloading the files.
|
||||||
// The other fields are used to override the configuration of the model.
|
// The other fields are used to override the configuration of the model.
|
||||||
@ -34,52 +25,3 @@ type GalleryModel struct {
|
|||||||
const (
|
const (
|
||||||
githubURI = "github:"
|
githubURI = "github:"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (request GalleryModel) DecodeURL() (string, error) {
|
|
||||||
input := request.URL
|
|
||||||
var rawURL string
|
|
||||||
|
|
||||||
if strings.HasPrefix(input, githubURI) {
|
|
||||||
parts := strings.Split(input, ":")
|
|
||||||
repoParts := strings.Split(parts[1], "@")
|
|
||||||
branch := "main"
|
|
||||||
|
|
||||||
if len(repoParts) > 1 {
|
|
||||||
branch = repoParts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
repoPath := strings.Split(repoParts[0], "/")
|
|
||||||
org := repoPath[0]
|
|
||||||
project := repoPath[1]
|
|
||||||
projectPath := strings.Join(repoPath[2:], "/")
|
|
||||||
|
|
||||||
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
|
||||||
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
|
|
||||||
// Handle regular URLs
|
|
||||||
u, err := url.Parse(input)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("invalid URL: %w", err)
|
|
||||||
}
|
|
||||||
rawURL = u.String()
|
|
||||||
// check if it's a file path
|
|
||||||
} else if strings.HasPrefix(input, "file://") {
|
|
||||||
return input, nil
|
|
||||||
} else {
|
|
||||||
|
|
||||||
return "", fmt.Errorf("invalid URL format: %s", input)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawURL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get fetches a model from a URL and unmarshals it into a struct
|
|
||||||
func (request GalleryModel) Get(i interface{}) error {
|
|
||||||
url, err := request.DecodeURL()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return utils.GetURI(url, func(d []byte) error {
|
|
||||||
return yaml.Unmarshal(d, i)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@ -6,37 +6,13 @@ import (
|
|||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
type example struct {
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Gallery API tests", func() {
|
var _ = Describe("Gallery API tests", func() {
|
||||||
|
|
||||||
Context("requests", func() {
|
Context("requests", func() {
|
||||||
It("parses github with a branch", func() {
|
It("parses github with a branch", func() {
|
||||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||||
var e example
|
e, err := GetGalleryConfigFromURL(req.URL)
|
||||||
err := req.Get(&e)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||||
})
|
})
|
||||||
It("parses github without a branch", func() {
|
|
||||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
|
||||||
str, err := req.DecodeURL()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
|
||||||
})
|
|
||||||
It("parses github without a branch", func() {
|
|
||||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
|
|
||||||
str, err := req.DecodeURL()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
|
||||||
})
|
|
||||||
It("parses URLS", func() {
|
|
||||||
req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
|
|
||||||
str, err := req.DecodeURL()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -1,12 +1,34 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetURI(url string, f func(i []byte) error) error {
|
const (
|
||||||
|
githubURI = "github:"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||||
|
if strings.HasPrefix(url, githubURI) {
|
||||||
|
parts := strings.Split(url, ":")
|
||||||
|
repoParts := strings.Split(parts[1], "@")
|
||||||
|
branch := "main"
|
||||||
|
|
||||||
|
if len(repoParts) > 1 {
|
||||||
|
branch = repoParts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
repoPath := strings.Split(repoParts[0], "/")
|
||||||
|
org := repoPath[0]
|
||||||
|
project := repoPath[1]
|
||||||
|
projectPath := strings.Join(repoPath[2:], "/")
|
||||||
|
|
||||||
|
url = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(url, "file://") {
|
if strings.HasPrefix(url, "file://") {
|
||||||
rawURL := strings.TrimPrefix(url, "file://")
|
rawURL := strings.TrimPrefix(url, "file://")
|
||||||
// Read the response body
|
// Read the response body
|
||||||
@ -16,7 +38,7 @@ func GetURI(url string, f func(i []byte) error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal YAML data into a struct
|
// Unmarshal YAML data into a struct
|
||||||
return f(body)
|
return f(url, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a GET request to the URL
|
// Send a GET request to the URL
|
||||||
@ -33,5 +55,5 @@ func GetURI(url string, f func(i []byte) error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal YAML data into a struct
|
// Unmarshal YAML data into a struct
|
||||||
return f(body)
|
return f(url, body)
|
||||||
}
|
}
|
||||||
|
36
pkg/utils/uri_test.go
Normal file
36
pkg/utils/uri_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package utils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Gallery API tests", func() {
|
||||||
|
Context("URI", func() {
|
||||||
|
It("parses github with a branch", func() {
|
||||||
|
Expect(
|
||||||
|
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||||
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
|
).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
It("parses github without a branch", func() {
|
||||||
|
Expect(
|
||||||
|
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
|
||||||
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
|
).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
It("parses github with urls", func() {
|
||||||
|
Expect(
|
||||||
|
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||||
|
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
|
).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
13
pkg/utils/utils_suite_test.go
Normal file
13
pkg/utils/utils_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package utils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUtils(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Utils test suite")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user