2024-04-17 21:33:49 +00:00
|
|
|
package startup
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
2024-06-13 14:12:46 +00:00
|
|
|
"fmt"
|
2024-07-11 13:04:05 +00:00
|
|
|
"net/url"
|
2024-04-17 21:33:49 +00:00
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2024-06-22 06:17:41 +00:00
|
|
|
"strings"
|
2024-04-17 21:33:49 +00:00
|
|
|
|
2024-06-24 15:32:12 +00:00
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
|
|
"github.com/mudler/LocalAI/core/gallery"
|
2024-06-23 08:24:36 +00:00
|
|
|
"github.com/mudler/LocalAI/embedded"
|
|
|
|
"github.com/mudler/LocalAI/pkg/downloader"
|
|
|
|
"github.com/mudler/LocalAI/pkg/utils"
|
2024-04-17 21:33:49 +00:00
|
|
|
"github.com/rs/zerolog/log"
|
|
|
|
)
|
|
|
|
|
2024-06-13 14:12:46 +00:00
|
|
|
// InstallModels will preload models from the given list of URLs and galleries
|
2024-04-17 21:33:49 +00:00
|
|
|
// It will download the model if it is not already present in the model path
|
|
|
|
// It will also try to resolve if the model is an embedded model YAML configuration
|
2024-07-10 11:18:32 +00:00
|
|
|
func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, enforceScan bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
2024-06-13 14:12:46 +00:00
|
|
|
// create an error that groups all errors
|
|
|
|
var err error
|
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
for _, url := range models {
|
|
|
|
|
|
|
|
// As a best effort, try to resolve the model from the remote library
|
|
|
|
// if it's not resolved we try with the other method below
|
|
|
|
if modelLibraryURL != "" {
|
2024-06-04 14:32:47 +00:00
|
|
|
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
|
2024-04-17 21:33:49 +00:00
|
|
|
if err == nil {
|
|
|
|
if lib[url] != "" {
|
|
|
|
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
|
|
|
url = lib[url]
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
url = embedded.ModelShortURL(url)
|
|
|
|
switch {
|
|
|
|
case embedded.ExistsInModelsLibrary(url):
|
2024-06-13 14:12:46 +00:00
|
|
|
modelYAML, e := embedded.ResolveContent(url)
|
2024-04-17 21:33:49 +00:00
|
|
|
// If we resolve something, just save it to disk and continue
|
2024-06-13 14:12:46 +00:00
|
|
|
if e != nil {
|
|
|
|
log.Error().Err(e).Msg("error resolving model content")
|
|
|
|
err = errors.Join(err, e)
|
2024-04-17 21:33:49 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
|
|
|
|
md5Name := utils.MD5(url)
|
|
|
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
2024-06-13 14:12:46 +00:00
|
|
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); err != nil {
|
|
|
|
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
|
|
|
err = errors.Join(err, e)
|
2024-04-17 21:33:49 +00:00
|
|
|
}
|
2024-06-22 06:17:41 +00:00
|
|
|
case downloader.LooksLikeOCI(url):
|
|
|
|
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)
|
|
|
|
|
|
|
|
// convert OCI image name to a file name.
|
|
|
|
ociName := strings.TrimPrefix(url, downloader.OCIPrefix)
|
|
|
|
ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
|
|
|
|
ociName = strings.ReplaceAll(ociName, "/", "__")
|
|
|
|
ociName = strings.ReplaceAll(ociName, ":", "__")
|
|
|
|
|
|
|
|
// check if file exists
|
|
|
|
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
|
|
|
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
|
|
|
|
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
|
|
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
|
|
|
})
|
|
|
|
if e != nil {
|
|
|
|
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
|
|
|
|
err = errors.Join(err, e)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
|
2024-04-17 21:33:49 +00:00
|
|
|
case downloader.LooksLikeURL(url):
|
2024-07-11 13:04:05 +00:00
|
|
|
log.Debug().Msgf("[startup] downloading %s", url)
|
|
|
|
|
|
|
|
// Extract filename from URL
|
|
|
|
fileName, e := filenameFromUrl(url)
|
|
|
|
if e != nil || fileName == "" {
|
|
|
|
fileName = utils.MD5(url)
|
|
|
|
if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") {
|
|
|
|
fileName = fileName + ".yaml"
|
|
|
|
}
|
|
|
|
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
|
|
|
|
//err = errors.Join(err, e)
|
|
|
|
//continue
|
|
|
|
}
|
2024-04-17 21:33:49 +00:00
|
|
|
|
2024-07-11 13:04:05 +00:00
|
|
|
modelPath := filepath.Join(modelPath, fileName)
|
|
|
|
|
|
|
|
if e := utils.VerifyPath(fileName, modelPath); e != nil {
|
|
|
|
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
|
|
|
|
err = errors.Join(err, e)
|
|
|
|
continue
|
|
|
|
}
|
2024-04-17 21:33:49 +00:00
|
|
|
|
|
|
|
// check if file exists
|
2024-07-11 13:04:05 +00:00
|
|
|
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
|
|
|
|
e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
2024-04-17 21:33:49 +00:00
|
|
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
|
|
|
})
|
2024-06-13 14:12:46 +00:00
|
|
|
if e != nil {
|
2024-07-11 13:04:05 +00:00
|
|
|
log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model")
|
2024-06-13 14:12:46 +00:00
|
|
|
err = errors.Join(err, e)
|
2024-04-17 21:33:49 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
default:
|
2024-06-13 14:12:46 +00:00
|
|
|
if _, e := os.Stat(url); e == nil {
|
2024-04-17 21:33:49 +00:00
|
|
|
log.Debug().Msgf("[startup] resolved local model: %s", url)
|
|
|
|
// copy to modelPath
|
|
|
|
md5Name := utils.MD5(url)
|
|
|
|
|
2024-06-13 14:12:46 +00:00
|
|
|
modelYAML, e := os.ReadFile(url)
|
|
|
|
if e != nil {
|
|
|
|
log.Error().Err(e).Str("filepath", url).Msg("error reading model definition")
|
|
|
|
err = errors.Join(err, e)
|
2024-04-17 21:33:49 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
2024-06-13 14:12:46 +00:00
|
|
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil {
|
2024-04-17 21:33:49 +00:00
|
|
|
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
|
2024-06-13 14:12:46 +00:00
|
|
|
err = errors.Join(err, e)
|
2024-04-17 21:33:49 +00:00
|
|
|
}
|
|
|
|
} else {
|
2024-06-13 14:12:46 +00:00
|
|
|
// Check if it's a model gallery, or print a warning
|
2024-07-10 11:18:32 +00:00
|
|
|
e, found := installModel(galleries, url, modelPath, downloadStatus, enforceScan)
|
2024-06-13 14:12:46 +00:00
|
|
|
if e != nil && found {
|
|
|
|
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
|
|
|
|
err = errors.Join(err, e)
|
|
|
|
} else if !found {
|
|
|
|
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
|
|
|
|
err = errors.Join(err, fmt.Errorf("failed resolving model '%s'", url))
|
|
|
|
}
|
2024-04-17 21:33:49 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-06-13 14:12:46 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-07-10 11:18:32 +00:00
|
|
|
func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64), enforceScan bool) (error, bool) {
|
2024-06-13 14:12:46 +00:00
|
|
|
models, err := gallery.AvailableGalleryModels(galleries, modelPath)
|
|
|
|
if err != nil {
|
|
|
|
return err, false
|
|
|
|
}
|
|
|
|
|
|
|
|
model := gallery.FindModel(models, modelName, modelPath)
|
|
|
|
if model == nil {
|
|
|
|
return err, false
|
|
|
|
}
|
|
|
|
|
|
|
|
if downloadStatus == nil {
|
|
|
|
downloadStatus = utils.DisplayDownloadFunction
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
|
2024-07-10 11:18:32 +00:00
|
|
|
err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus, enforceScan)
|
2024-06-13 14:12:46 +00:00
|
|
|
if err != nil {
|
|
|
|
return err, true
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, true
|
2024-04-17 21:33:49 +00:00
|
|
|
}
|
2024-07-11 13:04:05 +00:00
|
|
|
|
|
|
|
// filenameFromUrl returns the file name (last path element) of the URL in
// urlstr, with percent-encoding decoded.
//
// Anything from the first "@" onward is discarded before parsing. An error is
// returned when the URL cannot be parsed, when its path cannot be decoded, or
// when the path has no usable file name (empty, ".", ".." or a bare
// separator), so callers can fall back to a generated name.
func filenameFromUrl(urlstr string) (string, error) {
	// strip anything after @
	urlstr, _, _ = strings.Cut(urlstr, "@")

	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}

	// Decode with path rules, not query rules: QueryUnescape would turn a
	// literal '+' in the file name into a space.
	x, err := url.PathUnescape(u.EscapedPath())
	if err != nil {
		return "", fmt.Errorf("error due to escaping: %w", err)
	}

	name := filepath.Base(x)
	switch name {
	case ".", "..", string(filepath.Separator):
		// filepath.Base yields these for empty, parent-only or root paths;
		// none of them is a file name that can be joined onto a directory.
		return "", fmt.Errorf("no file name in url path %q", x)
	}
	return name, nil
}
|