LocalAI/core/config/backend_config_loader.go
Dave eed5706994
refactor: backend/service split, channel-based llm flow (#1963)
Refactor: channel based llm flow and services split

---------

Signed-off-by: Dave Lee <dave@gray101.com>
2024-04-13 09:45:34 +02:00

510 lines
12 KiB
Go

package config
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/charmbracelet/glamour"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
// BackendConfigLoader holds the backend configurations that have been loaded.
// Its methods take the embedded mutex before reading or writing the map.
type BackendConfigLoader struct {
// configs stores each loaded BackendConfig under its Name.
configs map[string]BackendConfig
// Embedded so that methods can call bcl.Lock / bcl.Unlock directly.
sync.Mutex
}
// ConfigLoaderOptions collects the values that ConfigLoaderOption functions
// can set: one field per LoadOption* constructor.
type ConfigLoaderOptions struct {
	// debug is set by LoadOptionDebug.
	debug bool
	// threads is set by LoadOptionThreads.
	threads int
	// ctxSize is set by LoadOptionContextSize.
	ctxSize int
	// f16 is set by LoadOptionF16.
	f16 bool
}
// LoadOptionDebug returns an option that stores debug in the options' debug field.
func LoadOptionDebug(debug bool) ConfigLoaderOption {
	setDebug := func(opts *ConfigLoaderOptions) { opts.debug = debug }
	return setDebug
}
// LoadOptionThreads returns an option that stores threads in the options' threads field.
func LoadOptionThreads(threads int) ConfigLoaderOption {
	setThreads := func(opts *ConfigLoaderOptions) { opts.threads = threads }
	return setThreads
}
// LoadOptionContextSize returns an option that stores ctxSize in the options' ctxSize field.
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
	setCtxSize := func(opts *ConfigLoaderOptions) { opts.ctxSize = ctxSize }
	return setCtxSize
}
// LoadOptionF16 returns an option that stores f16 in the options' f16 field.
func LoadOptionF16(f16 bool) ConfigLoaderOption {
	setF16 := func(opts *ConfigLoaderOptions) { opts.f16 = f16 }
	return setF16
}
// ConfigLoaderOption is a functional option: a function that modifies the
// ConfigLoaderOptions value it is handed. The LoadOption* constructors in this
// file build them.
type ConfigLoaderOption func(*ConfigLoaderOptions)
// Apply invokes each of the given options on lo, in argument order.
func (lo *ConfigLoaderOptions) Apply(options ...ConfigLoaderOption) {
	for i := range options {
		options[i](lo)
	}
}
// NewBackendConfigLoader returns a loader whose configuration map is
// initialised and empty.
func NewBackendConfigLoader() *BackendConfigLoader {
	loader := new(BackendConfigLoader)
	loader.configs = map[string]BackendConfig{}
	return loader
}
// LoadBackendConfig reads a single configuration from file and stores it in
// the loader under the configuration's Name, replacing any entry already
// registered under that name.
//
// The returned error wraps the underlying read or parse failure.
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
	bcl.Lock()
	defer bcl.Unlock()

	cfg, readErr := readBackendConfig(file, opts...)
	if readErr != nil {
		return fmt.Errorf("cannot read config file: %w", readErr)
	}

	bcl.configs[cfg.Name] = *cfg
	return nil
}
// GetBackendConfig looks up the configuration stored under m. The boolean
// reports whether an entry was found; the configuration is returned by value.
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
	bcl.Lock()
	defer bcl.Unlock()

	cfg, found := bcl.configs[m]
	return cfg, found
}
// GetAllBackendConfigs returns a copy of every loaded configuration, ordered
// by ascending Name. The result is nil when nothing has been loaded.
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
	bcl.Lock()
	defer bcl.Unlock()

	var all []BackendConfig
	for key := range bcl.configs {
		all = append(all, bcl.configs[key])
	}

	byName := func(a, b int) bool { return all[a].Name < all[b].Name }
	sort.SliceStable(all, byName)
	return all
}
// ListBackendConfigs returns the keys under which configurations are stored,
// in ascending lexical order. The result is nil when nothing has been loaded.
func (bcl *BackendConfigLoader) ListBackendConfigs() []string {
	bcl.Lock()
	defer bcl.Unlock()

	var res []string
	for k := range bcl.configs {
		res = append(res, k)
	}
	// Map iteration order is randomised, so sort to give callers a stable
	// listing, in line with GetAllBackendConfigs.
	sort.Strings(res)
	return res
}
// Preload prepares every loaded configuration for use with modelPath:
//
//   - each entry of the configuration's DownloadFiles is path-verified against
//     modelPath and handed to downloader.DownloadFile together with its SHA256;
//   - when the configured model, after downloader.ConvertURL, looks like a URL,
//     it is downloaded to modelPath under the MD5 of that URL (unless a file of
//     that name already exists) and the stored configuration is rewritten to
//     reference the MD5 file name;
//   - the configuration's name, description and usage are printed, rendered
//     with glamour unless rendering fails or NO_COLOR is set.
//
// The glamour style defaults to "dark" and can be overridden through the COLOR
// environment variable. The first verification or download error aborts the
// preload and is returned as is. The loader's mutex is held for the whole call.
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
	bcl.Lock()
	defer bcl.Unlock()

	progress := func(fileName, current, total string, percent float64) {
		utils.DisplayDownloadFunction(fileName, current, total, percent)
	}

	log.Info().Msgf("Preloading models from %s", modelPath)

	style := "dark"
	if fromEnv := os.Getenv("COLOR"); fromEnv != "" {
		style = fromEnv
	}

	// printMarkdown falls back to the raw text when rendering fails or when
	// colours are disabled through NO_COLOR.
	printMarkdown := func(text string) {
		rendered, err := glamour.Render(text, style)
		if err != nil || os.Getenv("NO_COLOR") != "" {
			fmt.Println(text)
			return
		}
		fmt.Println(rendered)
	}

	for key, cfg := range bcl.configs {
		// Fetch the auxiliary files first; DownloadFile receives the expected SHA.
		for _, file := range cfg.DownloadFiles {
			log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

			if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
				return err
			}

			target := filepath.Join(modelPath, file.Filename)
			if err := downloader.DownloadFile(file.URI, target, file.SHA256, progress); err != nil {
				return err
			}
		}

		modelURL := downloader.ConvertURL(cfg.PredictionOptions.Model)
		if downloader.LooksLikeURL(modelURL) {
			// Remote models are stored locally under the MD5 of their URL.
			md5Name := utils.MD5(modelURL)
			localPath := filepath.Join(modelPath, md5Name)

			// Only a missing file triggers a download; other stat errors are ignored.
			if _, statErr := os.Stat(localPath); errors.Is(statErr, os.ErrNotExist) {
				if err := downloader.DownloadFile(modelURL, localPath, "", progress); err != nil {
					return err
				}
			}

			// cfg is a copy of the map value, so write the change back.
			cfg.PredictionOptions.Model = md5Name
			bcl.configs[key] = cfg
		}

		if cfg.Name != "" {
			printMarkdown(fmt.Sprintf("**Model name**: _%s_", cfg.Name))
		}
		if cfg.Description != "" {
			printMarkdown(cfg.Description)
		}
		if cfg.Usage != "" {
			printMarkdown(cfg.Usage)
		}
	}

	return nil
}
// LoadBackendConfigsFromPath loads every YAML configuration found directly in
// the directory path and stores each one under its Name.
//
// Only entries whose file name contains ".yaml" or ".yml" are considered.
// A file that cannot be read or parsed is skipped without reporting an error;
// only failures to list the directory or to stat one of its entries are
// returned.
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
	bcl.Lock()
	defer bcl.Unlock()

	entries, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	// Resolve all entries up front so a stat failure aborts before anything is loaded.
	infos := make([]fs.FileInfo, len(entries))
	for i := range entries {
		info, infoErr := entries[i].Info()
		if infoErr != nil {
			return infoErr
		}
		infos[i] = info
	}

	for _, info := range infos {
		name := info.Name()

		// Anything that is not a YAML file is ignored.
		isYAML := strings.Contains(name, ".yaml") || strings.Contains(name, ".yml")
		if !isYAML {
			continue
		}

		cfg, readErr := readBackendConfig(filepath.Join(path, name), opts...)
		if readErr != nil {
			// Unreadable or malformed files are skipped silently.
			continue
		}
		bcl.configs[cfg.Name] = *cfg
	}

	return nil
}
// LoadBackendConfigFile reads a file containing a list of configurations and
// stores each of them under its Name, replacing entries that already use the
// same name.
//
// The returned error wraps the underlying read or parse failure.
func (bcl *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
	bcl.Lock()
	defer bcl.Unlock()

	loaded, err := readBackendConfigFile(file, opts...)
	if err != nil {
		return fmt.Errorf("cannot load config file: %w", err)
	}

	for i := range loaded {
		entry := loaded[i]
		bcl.configs[entry.Name] = *entry
	}
	return nil
}
// LoadBackendConfigFileByName returns the configuration to use for modelName.
//
// Resolution order:
//  1. a configuration already stored in the loader under modelName;
//  2. otherwise, if <modelPath>/<modelName>.yaml exists, that file is loaded
//     into the loader and the entry stored under modelName (if any) is used;
//  3. otherwise a fresh configuration whose model is modelName.
//
// SetDefaults(opts...) is applied to the result before it is returned. The
// result points at a copy, so changing it does not alter what the loader has
// stored.
//
// An error is returned only when the YAML file exists but cannot be loaded;
// it wraps the underlying cause.
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName string, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
	// Fallback used when nothing is registered or found on disk.
	cfg := &BackendConfig{
		PredictionOptions: schema.PredictionOptions{
			Model: modelName,
		},
	}

	cfgExisting, exists := bcl.GetBackendConfig(modelName)
	if exists {
		cfg = &cfgExisting
	} else {
		// Load a config file named after the model, if one is present.
		modelConfig := filepath.Join(modelPath, modelName+".yaml")
		if _, err := os.Stat(modelConfig); err == nil {
			if err := bcl.LoadBackendConfig(modelConfig); err != nil {
				// %w keeps the cause available to errors.Is / errors.As;
				// the message text is unchanged.
				return nil, fmt.Errorf("failed loading model config (%s) %w", modelConfig, err)
			}
			// The file is stored under the Name it declares, which may
			// differ from modelName, so look it up again.
			cfgExisting, exists = bcl.GetBackendConfig(modelName)
			if exists {
				cfg = &cfgExisting
			}
		}
	}

	cfg.SetDefaults(opts...)
	return cfg, nil
}
// readBackendConfigFile parses file as a YAML list of configurations and
// applies SetDefaults(opts...) to each entry.
//
// Null list entries are dropped from the result, so every returned pointer is
// non-nil. The returned error wraps the read or unmarshal failure.
func readBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
	data, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}

	configs := []*BackendConfig{}
	if err := yaml.Unmarshal(data, &configs); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	// A null list item (e.g. "- ~") decodes to a nil pointer; calling
	// SetDefaults on it, or dereferencing it in a caller, would panic.
	valid := configs[:0]
	for _, cfg := range configs {
		if cfg == nil {
			continue
		}
		cfg.SetDefaults(opts...)
		valid = append(valid, cfg)
	}

	return valid, nil
}
// readBackendConfig parses file as a single YAML configuration and applies
// SetDefaults(opts...) to it. The returned error wraps the read or unmarshal
// failure.
func readBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
	raw, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}

	var cfg BackendConfig
	if err = yaml.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	cfg.SetDefaults(opts...)
	return &cfg, nil
}
// LoadBackendConfigForModelAndOpenAIRequest resolves the configuration for
// modelFile and then overlays the parameters carried by input onto it.
//
// The configuration is looked up via LoadBackendConfigFileByName using
// appConfig.ModelPath, with the context size, debug, f16 and thread settings
// of appConfig passed as load options. Both the configuration and input may be
// modified by updateBackendConfigFromOpenAIRequest.
//
// On failure the configuration is nil, input is returned untouched and the
// error from LoadBackendConfigFileByName is passed through.
func (bcl *BackendConfigLoader) LoadBackendConfigForModelAndOpenAIRequest(modelFile string, input *schema.OpenAIRequest, appConfig *ApplicationConfig) (*BackendConfig, *schema.OpenAIRequest, error) {
	cfg, err := bcl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
		LoadOptionContextSize(appConfig.ContextSize),
		LoadOptionDebug(appConfig.Debug),
		LoadOptionF16(appConfig.F16),
		LoadOptionThreads(appConfig.Threads),
	)
	if err != nil {
		// cfg is nil here; applying the request to it would dereference nil.
		return nil, input, err
	}

	// Set the parameters for the language model prediction
	updateBackendConfigFromOpenAIRequest(cfg, input)

	return cfg, input, nil
}
// updateBackendConfigFromOpenAIRequest overlays the parameters carried by
// request onto bc and normalises parts of request along the way.
//
// Most scalar settings are copied only when the request holds a non-zero
// value, so an absent field leaves the configured value in place. Seed and
// TypicalP are copied whenever they are non-nil. Stop words, inputs and
// prompts are appended to what bc already holds.
//
// Besides bc, the request itself is modified:
//   - the Function of every entry in request.Tools is appended to
//     request.Functions;
//   - a non-nil request.ToolsChoice replaces request.FunctionCall with a
//     {"name": ...} map;
//   - every message gets StringContent (and StringImages, for image parts)
//     filled in from its Content.
//
// Decoding is best effort: content that cannot be decoded is skipped rather
// than reported. Both bc and request must be non-nil.
func updateBackendConfigFromOpenAIRequest(bc *BackendConfig, request *schema.OpenAIRequest) {
	if request.Echo {
		bc.Echo = request.Echo
	}
	if request.TopK != nil && *request.TopK != 0 {
		bc.TopK = request.TopK
	}
	if request.TopP != nil && *request.TopP != 0 {
		bc.TopP = request.TopP
	}

	if request.Backend != "" {
		bc.Backend = request.Backend
	}

	if request.ClipSkip != 0 {
		bc.Diffusers.ClipSkip = request.ClipSkip
	}

	if request.ModelBaseName != "" {
		bc.AutoGPTQ.ModelBaseName = request.ModelBaseName
	}

	if request.NegativePromptScale != 0 {
		bc.NegativePromptScale = request.NegativePromptScale
	}

	if request.UseFastTokenizer {
		bc.UseFastTokenizer = request.UseFastTokenizer
	}

	if request.NegativePrompt != "" {
		bc.NegativePrompt = request.NegativePrompt
	}

	if request.RopeFreqBase != 0 {
		bc.RopeFreqBase = request.RopeFreqBase
	}

	if request.RopeFreqScale != 0 {
		bc.RopeFreqScale = request.RopeFreqScale
	}

	if request.Grammar != "" {
		bc.Grammar = request.Grammar
	}

	if request.Temperature != nil && *request.Temperature != 0 {
		bc.Temperature = request.Temperature
	}

	if request.Maxtokens != nil && *request.Maxtokens != 0 {
		bc.Maxtokens = request.Maxtokens
	}

	// Stop can be a single string or a list; non-string list items are ignored.
	switch stop := request.Stop.(type) {
	case string:
		if stop != "" {
			bc.StopWords = append(bc.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				bc.StopWords = append(bc.StopWords, s)
			}
		}
	}

	// Tools are folded into the functions list.
	for _, tool := range request.Tools {
		request.Functions = append(request.Functions, tool.Function)
	}

	// A tool choice is translated into a function call selecting that tool's
	// function. Decoding is best effort: on failure toolChoice stays at its
	// zero value and the resulting name is empty.
	if request.ToolsChoice != nil {
		var toolChoice grammar.Tool
		switch content := request.ToolsChoice.(type) {
		case string:
			_ = json.Unmarshal([]byte(content), &toolChoice)
		case map[string]interface{}:
			dat, _ := json.Marshal(content)
			_ = json.Unmarshal(dat, &toolChoice)
		}
		request.FunctionCall = map[string]interface{}{
			"name": toolChoice.Function.Name,
		}
	}

	// Decode each request's message content.
	// index numbers the image placeholders across all messages of the request.
	index := 0
	for i, m := range request.Messages {
		switch content := m.Content.(type) {
		case string:
			request.Messages[i].StringContent = content
		case []interface{}:
			// Round-trip through JSON to obtain typed content parts. Errors
			// are deliberately ignored: whatever could be decoded is used.
			dat, _ := json.Marshal(content)
			c := []schema.Content{}
			_ = json.Unmarshal(dat, &c)
			for _, pp := range c {
				if pp.Type == "text" {
					request.Messages[i].StringContent = pp.Text
				} else if pp.Type == "image_url" {
					// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
					base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
					if err == nil {
						request.Messages[i].StringImages = append(request.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
						// set a placeholder for each image
						request.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + request.Messages[i].StringContent
						index++
					} else {
						fmt.Printf("Failed encoding image: %v\n", err)
					}
				}
			}
		}
	}

	if request.RepeatPenalty != 0 {
		bc.RepeatPenalty = request.RepeatPenalty
	}

	if request.FrequencyPenalty != 0 {
		bc.FrequencyPenalty = request.FrequencyPenalty
	}

	if request.PresencePenalty != 0 {
		bc.PresencePenalty = request.PresencePenalty
	}

	if request.Keep != 0 {
		bc.Keep = request.Keep
	}

	if request.Batch != 0 {
		bc.Batch = request.Batch
	}

	if request.IgnoreEOS {
		bc.IgnoreEOS = request.IgnoreEOS
	}

	if request.Seed != nil {
		bc.Seed = request.Seed
	}

	if request.TypicalP != nil {
		bc.TypicalP = request.TypicalP
	}

	// Input can be a string, a list of strings, or a list of token lists.
	switch inputs := request.Input.(type) {
	case string:
		if inputs != "" {
			bc.InputStrings = append(bc.InputStrings, inputs)
		}
	case []interface{}:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				bc.InputStrings = append(bc.InputStrings, i)
			case []interface{}:
				tokens := []int{}
				for _, ii := range i {
					// Tokens are expected as float64 numbers; the request is
					// caller-supplied, so skip anything else instead of
					// panicking on a failed type assertion.
					f, ok := ii.(float64)
					if !ok {
						continue
					}
					tokens = append(tokens, int(f))
				}
				bc.InputToken = append(bc.InputToken, tokens)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := request.FunctionCall.(type) {
	case string:
		if fnc != "" {
			bc.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		// A missing or non-string "name" yields the empty string.
		name, _ := fnc["name"].(string)
		bc.SetFunctionCallNameString(name)
	}

	// Prompt can be a single string or a list; non-string list items are ignored.
	switch p := request.Prompt.(type) {
	case string:
		bc.PromptStrings = append(bc.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				bc.PromptStrings = append(bc.PromptStrings, s)
			}
		}
	}
}