mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-26 07:41:05 +00:00
b6b8ab6c21
* feat(models): pull models from urls When using `run` we can now point directly to hf models via URL, for instance: ```bash local-ai run huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf ``` This will pull the gguf model and place it in the models folder - of course this relies on the gguf file being automatically detected by our guesser mechanism for this to be effective. Similarly, galleries can now refer to single files in the API requests. This also changes the download code so that `yaml` files are treated the same way, so config files are now saved with the appropriate name (and not hashed anymore). Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Adapt tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
187 lines
6.3 KiB
Go
187 lines
6.3 KiB
Go
package startup
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/gallery"
|
|
"github.com/mudler/LocalAI/embedded"
|
|
"github.com/mudler/LocalAI/pkg/downloader"
|
|
"github.com/mudler/LocalAI/pkg/utils"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// InstallModels will preload models from the given list of URLs and galleries
|
|
// It will download the model if it is not already present in the model path
|
|
// It will also try to resolve if the model is an embedded model YAML configuration
|
|
func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, enforceScan bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
|
// create an error that groups all errors
|
|
var err error
|
|
|
|
for _, url := range models {
|
|
|
|
// As a best effort, try to resolve the model from the remote library
|
|
// if it's not resolved we try with the other method below
|
|
if modelLibraryURL != "" {
|
|
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
|
|
if err == nil {
|
|
if lib[url] != "" {
|
|
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
|
url = lib[url]
|
|
}
|
|
}
|
|
}
|
|
|
|
url = embedded.ModelShortURL(url)
|
|
switch {
|
|
case embedded.ExistsInModelsLibrary(url):
|
|
modelYAML, e := embedded.ResolveContent(url)
|
|
// If we resolve something, just save it to disk and continue
|
|
if e != nil {
|
|
log.Error().Err(e).Msg("error resolving model content")
|
|
err = errors.Join(err, e)
|
|
continue
|
|
}
|
|
|
|
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
|
|
md5Name := utils.MD5(url)
|
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); err != nil {
|
|
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
|
err = errors.Join(err, e)
|
|
}
|
|
case downloader.LooksLikeOCI(url):
|
|
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)
|
|
|
|
// convert OCI image name to a file name.
|
|
ociName := strings.TrimPrefix(url, downloader.OCIPrefix)
|
|
ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
|
|
ociName = strings.ReplaceAll(ociName, "/", "__")
|
|
ociName = strings.ReplaceAll(ociName, ":", "__")
|
|
|
|
// check if file exists
|
|
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
|
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
|
|
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
|
})
|
|
if e != nil {
|
|
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
|
|
err = errors.Join(err, e)
|
|
}
|
|
}
|
|
|
|
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
|
|
case downloader.LooksLikeURL(url):
|
|
log.Debug().Msgf("[startup] downloading %s", url)
|
|
|
|
// Extract filename from URL
|
|
fileName, e := filenameFromUrl(url)
|
|
if e != nil || fileName == "" {
|
|
fileName = utils.MD5(url)
|
|
if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") {
|
|
fileName = fileName + ".yaml"
|
|
}
|
|
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
|
|
//err = errors.Join(err, e)
|
|
//continue
|
|
}
|
|
|
|
modelPath := filepath.Join(modelPath, fileName)
|
|
|
|
if e := utils.VerifyPath(fileName, modelPath); e != nil {
|
|
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
|
|
err = errors.Join(err, e)
|
|
continue
|
|
}
|
|
|
|
// check if file exists
|
|
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
|
|
e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
|
})
|
|
if e != nil {
|
|
log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model")
|
|
err = errors.Join(err, e)
|
|
}
|
|
}
|
|
default:
|
|
if _, e := os.Stat(url); e == nil {
|
|
log.Debug().Msgf("[startup] resolved local model: %s", url)
|
|
// copy to modelPath
|
|
md5Name := utils.MD5(url)
|
|
|
|
modelYAML, e := os.ReadFile(url)
|
|
if e != nil {
|
|
log.Error().Err(e).Str("filepath", url).Msg("error reading model definition")
|
|
err = errors.Join(err, e)
|
|
continue
|
|
}
|
|
|
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil {
|
|
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
|
|
err = errors.Join(err, e)
|
|
}
|
|
} else {
|
|
// Check if it's a model gallery, or print a warning
|
|
e, found := installModel(galleries, url, modelPath, downloadStatus, enforceScan)
|
|
if e != nil && found {
|
|
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
|
|
err = errors.Join(err, e)
|
|
} else if !found {
|
|
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
|
|
err = errors.Join(err, fmt.Errorf("failed resolving model '%s'", url))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64), enforceScan bool) (error, bool) {
|
|
models, err := gallery.AvailableGalleryModels(galleries, modelPath)
|
|
if err != nil {
|
|
return err, false
|
|
}
|
|
|
|
model := gallery.FindModel(models, modelName, modelPath)
|
|
if model == nil {
|
|
return err, false
|
|
}
|
|
|
|
if downloadStatus == nil {
|
|
downloadStatus = utils.DisplayDownloadFunction
|
|
}
|
|
|
|
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
|
|
err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus, enforceScan)
|
|
if err != nil {
|
|
return err, true
|
|
}
|
|
|
|
return nil, true
|
|
}
|
|
|
|
// filenameFromUrl returns the name under which the file referenced by urlstr
// should be stored: the last element of the URL path.
//
// Anything from the first "@" in the path onward is dropped first, so a
// reference suffix such as the branch in
// "huggingface://owner/repo/file.gguf@main" does not end up in the name
// (giving "file.gguf"). An "@" in the authority part (user info, as in
// "https://user@host/file.gguf") does not affect the result. The name comes
// from the decoded path: "%20" becomes a space, while a literal "+" is kept.
//
// When the path has no usable last element (empty, "/", "." or ".."), an empty
// string and a nil error are returned so the caller can pick a fallback name.
// A non-nil error is returned only when urlstr cannot be parsed.
func filenameFromUrl(urlstr string) (string, error) {
	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}

	// u.Path is already percent-decoded. url.QueryUnescape is meant for query
	// strings and would additionally turn "+" into a space.
	p, _, _ := strings.Cut(u.Path, "@")
	name := filepath.Base(p)

	// filepath.Base returns "." for an empty path and a lone separator for a
	// root path; neither (nor "..") names a file.
	if name == "." || name == ".." || strings.Trim(name, `/\`) == "" {
		return "", nil
	}
	return name, nil
}
|