2023-05-18 13:59:03 +00:00
|
|
|
package gallery
|
|
|
|
|
|
|
|
import (
|
2024-01-05 17:04:46 +00:00
|
|
|
"crypto/sha256"
|
2023-05-18 13:59:03 +00:00
|
|
|
"fmt"
|
2024-01-05 17:04:46 +00:00
|
|
|
"hash"
|
|
|
|
"io"
|
2023-05-18 13:59:03 +00:00
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2024-01-05 17:04:46 +00:00
|
|
|
"strconv"
|
2023-05-18 13:59:03 +00:00
|
|
|
|
2023-06-22 15:53:10 +00:00
|
|
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
2023-05-20 15:03:53 +00:00
|
|
|
"github.com/imdario/mergo"
|
2023-05-18 13:59:03 +00:00
|
|
|
"github.com/rs/zerolog/log"
|
|
|
|
"gopkg.in/yaml.v2"
|
|
|
|
)
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
description: |
|
|
|
|
foo
|
|
|
|
license: ""
|
|
|
|
|
|
|
|
urls:
|
|
|
|
-
|
|
|
|
-
|
|
|
|
|
|
|
|
name: "bar"
|
|
|
|
|
|
|
|
config_file: |
|
|
|
|
# Note: name is injected at install time (the alias requested by the user if given, otherwise the entry's name)
|
|
|
|
threads: 14
|
|
|
|
|
|
|
|
files:
|
|
|
|
- filename: ""
|
|
|
|
sha256: ""
|
|
|
|
uri: ""
|
|
|
|
|
|
|
|
prompt_templates:
|
|
|
|
- name: ""
|
|
|
|
content: ""
|
|
|
|
|
|
|
|
*/
|
2024-01-05 17:04:46 +00:00
|
|
|
// Config is the model configuration which contains all the model details
|
2023-06-25 20:51:02 +00:00
|
|
|
// This configuration is read from the gallery endpoint and is used to download and install the model
|
2024-01-05 17:04:46 +00:00
|
|
|
type Config struct {
|
2023-05-18 13:59:03 +00:00
|
|
|
Description string `yaml:"description"`
|
|
|
|
License string `yaml:"license"`
|
|
|
|
URLs []string `yaml:"urls"`
|
|
|
|
Name string `yaml:"name"`
|
|
|
|
ConfigFile string `yaml:"config_file"`
|
|
|
|
Files []File `yaml:"files"`
|
|
|
|
PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// File is a single artifact that InstallModel downloads for a gallery model.
type File struct {
	// Filename is the destination path, relative to the model directory.
	Filename string `yaml:"filename" json:"filename"`
	// SHA256 is the expected checksum; InstallModel hands it to
	// utils.DownloadFile for verification.
	SHA256 string `yaml:"sha256" json:"sha256"`
	// URI is where the file is downloaded from.
	URI string `yaml:"uri" json:"uri"`
}
|
|
|
|
|
|
|
|
// PromptTemplate is a named prompt template shipped with a gallery model.
// InstallModel writes Content verbatim to "<Name>.tmpl" in the model directory.
type PromptTemplate struct {
	Name    string `yaml:"name"`    // file name without the ".tmpl" extension
	Content string `yaml:"content"` // template body, written as-is
}
|
|
|
|
|
2024-01-05 17:04:46 +00:00
|
|
|
func GetGalleryConfigFromURL(url string) (Config, error) {
|
|
|
|
var config Config
|
2023-06-26 10:25:38 +00:00
|
|
|
err := utils.GetURI(url, func(url string, d []byte) error {
|
|
|
|
return yaml.Unmarshal(d, &config)
|
|
|
|
})
|
|
|
|
if err != nil {
|
2023-10-11 16:18:12 +00:00
|
|
|
log.Error().Msgf("GetGalleryConfigFromURL error for url %s\n%s", url, err.Error())
|
2023-06-26 10:25:38 +00:00
|
|
|
return config, err
|
|
|
|
}
|
|
|
|
return config, nil
|
|
|
|
}
|
|
|
|
|
2024-01-05 17:04:46 +00:00
|
|
|
func ReadConfigFile(filePath string) (*Config, error) {
|
2023-05-18 13:59:03 +00:00
|
|
|
// Read the YAML file
|
|
|
|
yamlFile, err := os.ReadFile(filePath)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to read YAML file: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Unmarshal YAML data into a Config struct
|
2024-01-05 17:04:46 +00:00
|
|
|
var config Config
|
2023-05-18 13:59:03 +00:00
|
|
|
err = yaml.Unmarshal(yamlFile, &config)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &config, nil
|
|
|
|
}
|
|
|
|
|
2024-01-05 17:04:46 +00:00
|
|
|
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
2023-05-18 13:59:03 +00:00
|
|
|
// Create base path if it doesn't exist
|
|
|
|
err := os.MkdirAll(basePath, 0755)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to create base path: %v", err)
|
|
|
|
}
|
|
|
|
|
2023-05-20 15:03:53 +00:00
|
|
|
if len(configOverrides) > 0 {
|
|
|
|
log.Debug().Msgf("Config overrides %+v", configOverrides)
|
|
|
|
}
|
|
|
|
|
2023-05-18 13:59:03 +00:00
|
|
|
// Download files and verify their SHA
|
|
|
|
for _, file := range config.Files {
|
|
|
|
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
|
|
|
|
2023-06-22 15:53:10 +00:00
|
|
|
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
|
2023-05-19 06:31:11 +00:00
|
|
|
return err
|
|
|
|
}
|
2023-05-18 13:59:03 +00:00
|
|
|
// Create file path
|
|
|
|
filePath := filepath.Join(basePath, file.Filename)
|
|
|
|
|
2023-12-18 17:58:44 +00:00
|
|
|
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
|
|
|
|
return err
|
2023-06-24 06:18:17 +00:00
|
|
|
}
|
2023-05-18 13:59:03 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Write prompt template contents to separate files
|
|
|
|
for _, template := range config.PromptTemplates {
|
2023-06-22 15:53:10 +00:00
|
|
|
if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil {
|
2023-05-19 06:31:11 +00:00
|
|
|
return err
|
|
|
|
}
|
2023-05-18 13:59:03 +00:00
|
|
|
// Create file path
|
|
|
|
filePath := filepath.Join(basePath, template.Name+".tmpl")
|
|
|
|
|
|
|
|
// Create parent directory
|
|
|
|
err := os.MkdirAll(filepath.Dir(filePath), 0755)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
|
|
|
|
}
|
|
|
|
// Create and write file content
|
|
|
|
err = os.WriteFile(filePath, []byte(template.Content), 0644)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Debug().Msgf("Prompt template %q written", template.Name)
|
|
|
|
}
|
|
|
|
|
|
|
|
name := config.Name
|
|
|
|
if nameOverride != "" {
|
|
|
|
name = nameOverride
|
|
|
|
}
|
|
|
|
|
2023-06-22 15:53:10 +00:00
|
|
|
if err := utils.VerifyPath(name+".yaml", basePath); err != nil {
|
2023-05-19 06:31:11 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
// write config file
|
|
|
|
if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
|
|
|
|
configFilePath := filepath.Join(basePath, name+".yaml")
|
2023-05-18 13:59:03 +00:00
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
// Read and update config file as map[string]interface{}
|
|
|
|
configMap := make(map[string]interface{})
|
|
|
|
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
|
|
|
}
|
2023-05-18 13:59:03 +00:00
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
configMap["name"] = name
|
2023-05-18 13:59:03 +00:00
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-05-20 15:03:53 +00:00
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
// Write updated config file
|
|
|
|
updatedConfigYAML, err := yaml.Marshal(configMap)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
|
|
|
|
}
|
2023-05-18 13:59:03 +00:00
|
|
|
|
2023-06-24 06:18:17 +00:00
|
|
|
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to write updated config file: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Debug().Msgf("Written config file %s", configFilePath)
|
2023-05-18 13:59:03 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
2024-01-05 17:04:46 +00:00
|
|
|
|
|
|
|
type progressWriter struct {
|
|
|
|
fileName string
|
|
|
|
total int64
|
|
|
|
written int64
|
|
|
|
downloadStatus func(string, string, string, float64)
|
|
|
|
hash hash.Hash
|
|
|
|
}
|
|
|
|
|
|
|
|
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
|
|
|
n, err = pw.hash.Write(p)
|
|
|
|
pw.written += int64(n)
|
|
|
|
|
|
|
|
if pw.total > 0 {
|
|
|
|
percentage := float64(pw.written) / float64(pw.total) * 100
|
|
|
|
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
|
|
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
|
|
|
} else {
|
|
|
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// formatBytes renders a byte count for humans using binary (1024-based)
// prefixes: values below 1 KiB print as "<n> B", larger ones as a
// one-decimal figure with a KiB/MiB/GiB/TiB/PiB/EiB suffix.
func formatBytes(bytes int64) string {
	const base = 1024
	if bytes < base {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	const prefixes = "KMGTPE"
	divisor, idx := int64(base), 0
	// Grow the divisor while the scaled value would still be >= 1024.
	// int64 tops out below 8 EiB, so idx never exceeds len(prefixes)-1.
	for bytes/divisor >= base {
		divisor *= base
		idx++
	}

	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(divisor), prefixes[idx])
}
|
|
|
|
|
|
|
|
// calculateSHA streams the file at filePath through SHA-256 and returns the
// digest as a lowercase hex string. Open and read errors are returned as-is.
func calculateSHA(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer f.Close()

	// Stream rather than read the whole file: model files can be large.
	digest := sha256.New()
	if _, err = io.Copy(digest, f); err != nil {
		return "", err
	}

	return fmt.Sprintf("%x", digest.Sum(nil)), nil
}
|