refactor: gallery inconsistencies (#2647)

* refactor(gallery): move under core/

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(unarchive): do not allow symlinks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-06-24 17:32:12 +02:00 committed by GitHub
parent 69206fcd4b
commit a181dd0ebc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 93 additions and 54 deletions

View File

@ -12,7 +12,7 @@ import (
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"

View File

@ -5,9 +5,10 @@ import (
"fmt" "fmt"
cliContext "github.com/mudler/LocalAI/core/cli/context" cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/startup" "github.com/mudler/LocalAI/pkg/startup"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/schollz/progressbar/v3" "github.com/schollz/progressbar/v3"
@ -34,7 +35,7 @@ type ModelsCMD struct {
} }
func (ml *ModelsList) Run(ctx *cliContext.Context) error { func (ml *ModelsList) Run(ctx *cliContext.Context) error {
var galleries []gallery.Gallery var galleries []config.Gallery
if err := json.Unmarshal([]byte(ml.Galleries), &galleries); err != nil { if err := json.Unmarshal([]byte(ml.Galleries), &galleries); err != nil {
log.Error().Err(err).Msg("unable to load galleries") log.Error().Err(err).Msg("unable to load galleries")
} }
@ -54,7 +55,7 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
} }
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
var galleries []gallery.Gallery var galleries []config.Gallery
if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
log.Error().Err(err).Msg("unable to load galleries") log.Error().Err(err).Msg("unable to load galleries")
} }

View File

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -36,7 +35,7 @@ type ApplicationConfig struct {
ModelLibraryURL string ModelLibraryURL string
Galleries []gallery.Gallery Galleries []Gallery
BackendAssets embed.FS BackendAssets embed.FS
AssetsDestination string AssetsDestination string
@ -180,10 +179,10 @@ func WithBackendAssets(f embed.FS) AppOption {
func WithStringGalleries(galls string) AppOption { func WithStringGalleries(galls string) AppOption {
return func(o *ApplicationConfig) { return func(o *ApplicationConfig) {
if galls == "" { if galls == "" {
o.Galleries = []gallery.Gallery{} o.Galleries = []Gallery{}
return return
} }
var galleries []gallery.Gallery var galleries []Gallery
if err := json.Unmarshal([]byte(galls), &galleries); err != nil { if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
log.Error().Err(err).Msg("failed loading galleries") log.Error().Err(err).Msg("failed loading galleries")
} }
@ -191,7 +190,7 @@ func WithStringGalleries(galls string) AppOption {
} }
} }
func WithGalleries(galleries []gallery.Gallery) AppOption { func WithGalleries(galleries []Gallery) AppOption {
return func(o *ApplicationConfig) { return func(o *ApplicationConfig) {
o.Galleries = append(o.Galleries, galleries...) o.Galleries = append(o.Galleries, galleries...)
} }

View File

@ -390,10 +390,6 @@ func (c *BackendConfig) Validate() bool {
} }
} }
if c.Name == "" {
return false
}
if c.Backend != "" { if c.Backend != "" {
// a regex that checks that is a string name with no special characters, except '-' and '_' // a regex that checks that is a string name with no special characters, except '-' and '_'
re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`) re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)

View File

@ -16,7 +16,8 @@ var _ = Describe("Test cases for config related functions", func() {
Expect(err).To(BeNil()) Expect(err).To(BeNil())
defer os.Remove(tmp.Name()) defer os.Remove(tmp.Name())
_, err = tmp.WriteString( _, err = tmp.WriteString(
`backend: "foo-bar" `backend: "../foo-bar"
name: "foo"
parameters: parameters:
model: "foo-bar"`) model: "foo-bar"`)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

6
core/config/gallery.go Normal file
View File

@ -0,0 +1,6 @@
package config
type Gallery struct {
URL string `json:"url" yaml:"url"`
Name string `json:"name" yaml:"name"`
}

View File

@ -8,18 +8,14 @@ import (
"strings" "strings"
"github.com/imdario/mergo" "github.com/imdario/mergo"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
type Gallery struct {
URL string `json:"url" yaml:"url"`
Name string `json:"name" yaml:"name"`
}
// Installs a model from the gallery // Installs a model from the gallery
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
name = strings.ReplaceAll(name, string(os.PathSeparator), "__") name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
@ -117,7 +113,7 @@ func FindModel(models []*GalleryModel, name string, basePath string) *GalleryMod
// List available models // List available models
// Models galleries are a list of yaml files that are hosted on a remote server (for example github). // Models galleries are a list of yaml files that are hosted on a remote server (for example github).
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting. // Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) { func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel var models []*GalleryModel
// Get models from galleries // Get models from galleries
@ -146,7 +142,7 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error)
return refFile, err return refFile, err
} }
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) { func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{} var models []*GalleryModel = []*GalleryModel{}
if strings.HasSuffix(gallery.URL, ".ref") { if strings.HasSuffix(gallery.URL, ".ref") {

View File

@ -6,8 +6,10 @@ import (
"path/filepath" "path/filepath"
"github.com/imdario/mergo" "github.com/imdario/mergo"
lconfig "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -172,6 +174,15 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
return fmt.Errorf("failed to marshal updated config YAML: %v", err) return fmt.Errorf("failed to marshal updated config YAML: %v", err)
} }
backendConfig := lconfig.BackendConfig{}
err = yaml.Unmarshal(updatedConfigYAML, &backendConfig)
if err != nil {
return fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
}
if !backendConfig.Validate() {
return fmt.Errorf("failed to validate updated config YAML")
}
err = os.WriteFile(configFilePath, updatedConfigYAML, 0600) err = os.WriteFile(configFilePath, updatedConfigYAML, 0600)
if err != nil { if err != nil {
return fmt.Errorf("failed to write updated config file: %v", err) return fmt.Errorf("failed to write updated config file: %v", err)

View File

@ -5,7 +5,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
. "github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/gallery"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@ -54,7 +55,7 @@ var _ = Describe("Model test", func() {
err = os.WriteFile(galleryFilePath, out, 0600) err = os.WriteFile(galleryFilePath, out, 0600)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath) Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath)
galleries := []Gallery{ galleries := []config.Gallery{
{ {
Name: "test", Name: "test",
URL: "file://" + galleryFilePath, URL: "file://" + galleryFilePath,

View File

@ -1,5 +1,7 @@
package gallery package gallery
import "github.com/mudler/LocalAI/core/config"
type GalleryOp struct { type GalleryOp struct {
Id string Id string
GalleryModelName string GalleryModelName string
@ -7,7 +9,7 @@ type GalleryOp struct {
Delete bool Delete bool
Req GalleryModel Req GalleryModel
Galleries []Gallery Galleries []config.Gallery
} }
type GalleryOpStatus struct { type GalleryOpStatus struct {

View File

@ -3,6 +3,8 @@ package gallery
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/mudler/LocalAI/core/config"
) )
// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint. // GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
@ -23,7 +25,7 @@ type GalleryModel struct {
// AdditionalFiles are used to add additional files to the model // AdditionalFiles are used to add additional files to the model
AdditionalFiles []File `json:"files,omitempty" yaml:"files,omitempty"` AdditionalFiles []File `json:"files,omitempty" yaml:"files,omitempty"`
// Gallery is a reference to the gallery which contains the model // Gallery is a reference to the gallery which contains the model
Gallery Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"` Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"`
// Installed is used to indicate if the model is installed or not // Installed is used to indicate if the model is installed or not
Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"` Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"`
} }

View File

@ -1,7 +1,7 @@
package gallery_test package gallery_test
import ( import (
. "github.com/mudler/LocalAI/pkg/gallery" . "github.com/mudler/LocalAI/core/gallery"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )

View File

@ -19,8 +19,8 @@ import (
"github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/core/startup"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -247,7 +247,7 @@ var _ = Describe("API test", func() {
err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600) err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{ galleries := []config.Gallery{
{ {
Name: "test", Name: "test",
URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"), URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"),
@ -603,7 +603,7 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{ galleries := []config.Gallery{
{ {
Name: "model-gallery", Name: "model-gallery",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",

View File

@ -6,8 +6,8 @@ import (
"github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/attrs"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/xsync" "github.com/mudler/LocalAI/pkg/xsync"
) )

View File

@ -7,13 +7,14 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type ModelGalleryEndpointService struct { type ModelGalleryEndpointService struct {
galleries []gallery.Gallery galleries []config.Gallery
modelPath string modelPath string
galleryApplier *services.GalleryService galleryApplier *services.GalleryService
} }
@ -24,7 +25,7 @@ type GalleryModel struct {
gallery.GalleryModel gallery.GalleryModel
} }
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService { func CreateModelGalleryEndpointService(galleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{ return ModelGalleryEndpointService{
galleries: galleries, galleries: galleries,
modelPath: modelPath, modelPath: modelPath,
@ -129,12 +130,12 @@ func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fib
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(gallery.Gallery) input := new(config.Gallery)
// Get input data from the request body // Get input data from the request body
if err := c.BodyParser(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { if slices.ContainsFunc(mgs.galleries, func(gallery config.Gallery) bool {
return gallery.Name == input.Name return gallery.Name == input.Name
}) { }) {
return fmt.Errorf("%s already exists", input.Name) return fmt.Errorf("%s already exists", input.Name)
@ -151,17 +152,17 @@ func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(gallery.Gallery) input := new(config.Gallery)
// Get input data from the request body // Get input data from the request body
if err := c.BodyParser(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { if !slices.ContainsFunc(mgs.galleries, func(gallery config.Gallery) bool {
return gallery.Name == input.Name return gallery.Name == input.Name
}) { }) {
return fmt.Errorf("%s is not currently registered", input.Name) return fmt.Errorf("%s is not currently registered", input.Name)
} }
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery config.Gallery) bool {
return gallery.Name == input.Name return gallery.Name == input.Name
}) })
return c.Send(nil) return c.Send(nil)

View File

@ -3,8 +3,8 @@ package localai
import ( import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
) )

View File

@ -257,5 +257,9 @@ func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *c
// Set the parameters for the language model prediction // Set the parameters for the language model prediction
updateRequestConfig(cfg, input) updateRequestConfig(cfg, input)
if !cfg.Validate() {
return nil, nil, fmt.Errorf("failed to validate config")
}
return cfg, input, err return cfg, input, err
} }

View File

@ -7,11 +7,11 @@ import (
"strings" "strings"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/elements" "github.com/mudler/LocalAI/core/http/elements"
"github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/xsync" "github.com/mudler/LocalAI/pkg/xsync"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"

View File

@ -9,7 +9,7 @@ import (
"sync" "sync"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/startup" "github.com/mudler/LocalAI/pkg/startup"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -96,6 +96,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
// delete a model // delete a model
if op.Delete { if op.Delete {
modelConfig := &config.BackendConfig{} modelConfig := &config.BackendConfig{}
// Galleryname is the name of the model in this case // Galleryname is the name of the model in this case
dat, err := os.ReadFile(filepath.Join(g.appConfig.ModelPath, op.GalleryModelName+".yaml")) dat, err := os.ReadFile(filepath.Join(g.appConfig.ModelPath, op.GalleryModelName+".yaml"))
if err != nil { if err != nil {
@ -174,7 +175,7 @@ type galleryModel struct {
ID string `json:"id"` ID string `json:"id"`
} }
func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error { func processRequests(modelPath string, galleries []config.Gallery, requests []galleryModel) error {
var err error var err error
for _, r := range requests { for _, r := range requests {
utils.ResetDownloadTimers() utils.ResetDownloadTimers()
@ -189,7 +190,7 @@ func processRequests(modelPath string, galleries []gallery.Gallery, requests []g
return err return err
} }
func ApplyGalleryFromFile(modelPath, s string, galleries []gallery.Gallery) error { func ApplyGalleryFromFile(modelPath, s string, galleries []config.Gallery) error {
dat, err := os.ReadFile(s) dat, err := os.ReadFile(s)
if err != nil { if err != nil {
return err return err
@ -203,7 +204,7 @@ func ApplyGalleryFromFile(modelPath, s string, galleries []gallery.Gallery) erro
return processRequests(modelPath, galleries, requests) return processRequests(modelPath, galleries, requests)
} }
func ApplyGalleryFromString(modelPath, s string, galleries []gallery.Gallery) error { func ApplyGalleryFromString(modelPath, s string, galleries []config.Gallery) error {
var requests []galleryModel var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests) err := json.Unmarshal([]byte(s), &requests)
if err != nil { if err != nil {

View File

@ -13,6 +13,7 @@ import (
"github.com/klauspost/cpuid/v2" "github.com/klauspost/cpuid/v2"
grpc "github.com/mudler/LocalAI/pkg/grpc" grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/library" "github.com/mudler/LocalAI/pkg/library"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/phayes/freeport" "github.com/phayes/freeport"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -309,6 +310,9 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
} }
} else { } else {
grpcProcess := backendPath(o.assetDir, backend) grpcProcess := backendPath(o.assetDir, backend)
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
}
if autoDetect { if autoDetect {
// autoDetect GRPC process to start based on system capabilities // autoDetect GRPC process to start based on system capabilities

View File

@ -7,9 +7,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/embedded" "github.com/mudler/LocalAI/embedded"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/gallery"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -17,7 +18,7 @@ import (
// InstallModels will preload models from the given list of URLs and galleries // InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path // It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration // It will also try to resolve if the model is an embedded model YAML configuration
func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error { func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error {
// create an error that groups all errors // create an error that groups all errors
var err error var err error
@ -126,7 +127,7 @@ func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPat
return err return err
} }
func installModel(galleries []gallery.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) { func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, modelPath) models, err := gallery.AvailableGalleryModels(galleries, modelPath)
if err != nil { if err != nil {
return err, false return err, false

View File

@ -5,7 +5,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/mudler/LocalAI/pkg/gallery" "github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/pkg/startup" . "github.com/mudler/LocalAI/pkg/startup"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
@ -22,7 +22,7 @@ var _ = Describe("Preload test", func() {
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
InstallModels([]gallery.Gallery{}, libraryURL, tmpdir, nil, "phi-2") InstallModels([]config.Gallery{}, libraryURL, tmpdir, nil, "phi-2")
resultFile := filepath.Join(tmpdir, fileName) resultFile := filepath.Join(tmpdir, fileName)
@ -38,7 +38,7 @@ var _ = Describe("Preload test", func() {
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) InstallModels([]config.Gallery{}, "", tmpdir, nil, url)
resultFile := filepath.Join(tmpdir, fileName) resultFile := filepath.Join(tmpdir, fileName)
@ -52,7 +52,7 @@ var _ = Describe("Preload test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
url := "phi-2" url := "phi-2"
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) InstallModels([]config.Gallery{}, "", tmpdir, nil, url)
entry, err := os.ReadDir(tmpdir) entry, err := os.ReadDir(tmpdir)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -70,7 +70,7 @@ var _ = Describe("Preload test", func() {
url := "mistral-openorca" url := "mistral-openorca"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) InstallModels([]config.Gallery{}, "", tmpdir, nil, url)
resultFile := filepath.Join(tmpdir, fileName) resultFile := filepath.Join(tmpdir, fileName)

View File

@ -2,6 +2,7 @@ package utils
import ( import (
"fmt" "fmt"
"os"
"github.com/mholt/archiver/v3" "github.com/mholt/archiver/v3"
) )
@ -52,5 +53,17 @@ func ExtractArchive(archive, dst string) error {
case *archiver.TarZstd: case *archiver.TarZstd:
v.Tar = mytar v.Tar = mytar
} }
err = archiver.Walk(archive, func(f archiver.File) error {
if f.FileInfo.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("archive contains a symlink")
}
return nil
})
if err != nil {
return err
}
return un.Unarchive(archive, dst) return un.Unarchive(archive, dst)
} }