From f895d066055728e2744044ce6390a222bc24d095 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 13 Mar 2024 10:05:30 +0100 Subject: [PATCH] fix(config): set better defaults for inferencing (#1822) * fix(defaults): set better defaults for inferencing This changeset aim to have better defaults and to properly detect when no inference settings are provided with the model. If not specified, we defaults to mirostat sampling, and offload all the GPU layers (if a GPU is detected). Related to https://github.com/mudler/LocalAI/issues/1373 and https://github.com/mudler/LocalAI/issues/1723 * Adapt tests * Also pre-initialize default seed --- core/backend/embeddings.go | 2 +- core/backend/image.go | 6 +- core/backend/llm.go | 6 +- core/backend/options.go | 45 ++--- core/backend/transcript.go | 4 +- core/config/backend_config.go | 240 +++++++++++++++++++------- core/http/api_test.go | 2 +- core/http/endpoints/localai/tts.go | 9 +- core/http/endpoints/openai/image.go | 2 +- core/http/endpoints/openai/request.go | 33 ++-- core/schema/prediction.go | 17 +- main.go | 2 +- 12 files changed, 235 insertions(+), 133 deletions(-) diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 0a74ea4c..94310854 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(backendConfig.Threads)), + model.WithThreads(uint32(*backendConfig.Threads)), model.WithAssetDir(appConfig.AssetsDestination), model.WithModel(modelFile), model.WithContext(appConfig.Context), diff --git a/core/backend/image.go b/core/backend/image.go index 79b8d4ba..b0cffb0b 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -9,14 +9,14 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { threads := backendConfig.Threads - if threads == 0 && appConfig.Threads != 0 { - threads = appConfig.Threads + if *threads == 0 && appConfig.Threads != 0 { + threads = &appConfig.Threads } gRPCOpts := gRPCModelOpts(backendConfig) opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(backendConfig.Backend), model.WithAssetDir(appConfig.AssetsDestination), - model.WithThreads(uint32(threads)), + model.WithThreads(uint32(*threads)), model.WithContext(appConfig.Context), model.WithModel(backendConfig.Model), model.WithLoadGRPCLoadModelOpts(gRPCOpts), diff --git a/core/backend/llm.go b/core/backend/llm.go index 54e26188..d5e14df0 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -29,8 +29,8 @@ type TokenUsage struct { func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model threads := c.Threads - if threads == 0 && o.Threads != 0 { - threads = o.Threads + if *threads == 0 && o.Threads != 0 { + threads = &o.Threads } grpcOpts := gRPCModelOpts(c) @@ -39,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(threads)), // some models uses this to allocate threads during startup + model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), diff --git a/core/backend/options.go b/core/backend/options.go index 3af6f679..bc7fa5a4 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -46,15 +46,15 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { CFGScale: c.Diffusers.CFGScale, LoraAdapter: c.LoraAdapter, LoraScale: c.LoraScale, - F16Memory: c.F16, + F16Memory: *c.F16, LoraBase: c.LoraBase, IMG2IMG: c.Diffusers.IMG2IMG, CLIPModel: c.Diffusers.ClipModel, CLIPSubfolder: c.Diffusers.ClipSubFolder, CLIPSkip: int32(c.Diffusers.ClipSkip), ControlNet: c.Diffusers.ControlNet, - ContextSize: int32(c.ContextSize), - Seed: int32(c.Seed), + ContextSize: int32(*c.ContextSize), + Seed: int32(*c.Seed), NBatch: int32(b), NoMulMatQ: c.NoMulMatQ, DraftModel: c.DraftModel, @@ -72,18 +72,18 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { YarnBetaSlow: c.YarnBetaSlow, NGQA: c.NGQA, RMSNormEps: c.RMSNormEps, - MLock: c.MMlock, + MLock: *c.MMlock, RopeFreqBase: c.RopeFreqBase, RopeScaling: c.RopeScaling, Type: c.ModelType, RopeFreqScale: c.RopeFreqScale, NUMA: c.NUMA, Embeddings: c.Embeddings, - LowVRAM: c.LowVRAM, - NGPULayers: int32(c.NGPULayers), - MMap: c.MMap, + LowVRAM: *c.LowVRAM, + NGPULayers: int32(*c.NGPULayers), + MMap: *c.MMap, MainGPU: c.MainGPU, - Threads: int32(c.Threads), + Threads: int32(*c.Threads), TensorSplit: c.TensorSplit, // AutoGPTQ ModelBaseName: c.AutoGPTQ.ModelBaseName, @@ -102,36 +102,37 @@ func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOption os.MkdirAll(filepath.Dir(p), 0755) promptCachePath = p } + return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), + Temperature: float32(*c.Temperature), + TopP: float32(*c.TopP), NDraft: c.NDraft, - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), + TopK: int32(*c.TopK), + Tokens: int32(*c.Maxtokens), + Threads: int32(*c.Threads), PromptCacheAll: c.PromptCacheAll, PromptCacheRO: c.PromptCacheRO, PromptCachePath: promptCachePath, - F16KV: c.F16, - DebugMode: c.Debug, + F16KV: *c.F16, + DebugMode: *c.Debug, Grammar: c.Grammar, NegativePromptScale: c.NegativePromptScale, RopeFreqBase: c.RopeFreqBase, RopeFreqScale: c.RopeFreqScale, NegativePrompt: c.NegativePrompt, - Mirostat: int32(c.LLMConfig.Mirostat), - MirostatETA: float32(c.LLMConfig.MirostatETA), - MirostatTAU: float32(c.LLMConfig.MirostatTAU), - Debug: c.Debug, + Mirostat: int32(*c.LLMConfig.Mirostat), + MirostatETA: float32(*c.LLMConfig.MirostatETA), + MirostatTAU: float32(*c.LLMConfig.MirostatTAU), + Debug: *c.Debug, StopPrompts: c.StopWords, Repeat: int32(c.RepeatPenalty), NKeep: int32(c.Keep), Batch: int32(c.Batch), IgnoreEOS: c.IgnoreEOS, - Seed: int32(c.Seed), + Seed: int32(*c.Seed), FrequencyPenalty: float32(c.FrequencyPenalty), - MLock: c.MMlock, - MMap: c.MMap, + MLock: *c.MMlock, + MMap: *c.MMap, MainGPU: c.MainGPU, TensorSplit: c.TensorSplit, TailFreeSamplingZ: float32(c.TFZ), diff --git a/core/backend/transcript.go b/core/backend/transcript.go index bbb4f4b4..4c3859df 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -17,7 +17,7 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo model.WithBackendString(model.WhisperBackend), model.WithModel(backendConfig.Model), model.WithContext(appConfig.Context), - model.WithThreads(uint32(backendConfig.Threads)), + model.WithThreads(uint32(*backendConfig.Threads)), model.WithAssetDir(appConfig.AssetsDestination), }) @@ -33,6 +33,6 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, - Threads: uint32(backendConfig.Threads), + Threads: uint32(*backendConfig.Threads), }) } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 63e5855c..53326b3f 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io/fs" + "math/rand" "os" "path/filepath" "strings" @@ -20,9 +21,9 @@ type BackendConfig struct { schema.PredictionOptions `yaml:"parameters"` Name string `yaml:"name"` - F16 bool `yaml:"f16"` - Threads int `yaml:"threads"` - Debug bool `yaml:"debug"` + F16 *bool `yaml:"f16"` + Threads *int `yaml:"threads"` + Debug *bool `yaml:"debug"` Roles map[string]string `yaml:"roles"` Embeddings bool `yaml:"embeddings"` Backend string `yaml:"backend"` @@ -105,20 +106,20 @@ type LLMConfig struct { PromptCachePath string `yaml:"prompt_cache_path"` PromptCacheAll bool `yaml:"prompt_cache_all"` PromptCacheRO bool `yaml:"prompt_cache_ro"` - MirostatETA float64 `yaml:"mirostat_eta"` - MirostatTAU float64 `yaml:"mirostat_tau"` - Mirostat int `yaml:"mirostat"` - NGPULayers int `yaml:"gpu_layers"` - MMap bool `yaml:"mmap"` - MMlock bool `yaml:"mmlock"` - LowVRAM bool `yaml:"low_vram"` + MirostatETA *float64 `yaml:"mirostat_eta"` + MirostatTAU *float64 `yaml:"mirostat_tau"` + Mirostat *int `yaml:"mirostat"` + NGPULayers *int `yaml:"gpu_layers"` + MMap *bool `yaml:"mmap"` + MMlock *bool `yaml:"mmlock"` + LowVRAM *bool `yaml:"low_vram"` Grammar string `yaml:"grammar"` StopWords []string `yaml:"stopwords"` Cutstrings []string `yaml:"cutstrings"` TrimSpace []string `yaml:"trimspace"` TrimSuffix []string `yaml:"trimsuffix"` - ContextSize int `yaml:"context_size"` + ContextSize *int `yaml:"context_size"` NUMA bool `yaml:"numa"` LoraAdapter string `yaml:"lora_adapter"` LoraBase string `yaml:"lora_base"` @@ -185,19 +186,96 @@ func (c *BackendConfig) FunctionToCall() string { return c.functionCallNameString } -func defaultPredictOptions(modelFile string) schema.PredictionOptions { - return schema.PredictionOptions{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, - } -} +func (cfg *BackendConfig) SetDefaults(debug bool, threads, ctx int, f16 bool) { + defaultTopP := 0.7 + defaultTopK := 80 + defaultTemp := 0.9 + defaultMaxTokens := 2048 + defaultMirostat := 2 + defaultMirostatTAU := 5.0 + defaultMirostatETA := 0.1 -func DefaultConfig(modelFile string) *BackendConfig { - return &BackendConfig{ - PredictionOptions: defaultPredictOptions(modelFile), + // Try to offload all GPU layers (if GPU is found) + defaultNGPULayers := 99999999 + + trueV := true + falseV := false + + if cfg.Seed == nil { + // random number generator seed + defaultSeed := int(rand.Int31()) + cfg.Seed = &defaultSeed + } + + if cfg.TopK == nil { + cfg.TopK = &defaultTopK + } + + if cfg.MMap == nil { + // MMap is enabled by default + cfg.MMap = &trueV + } + + if cfg.MMlock == nil { + // MMlock is disabled by default + cfg.MMlock = &falseV + } + + if cfg.TopP == nil { + cfg.TopP = &defaultTopP + } + if cfg.Temperature == nil { + cfg.Temperature = &defaultTemp + } + + if cfg.Maxtokens == nil { + cfg.Maxtokens = &defaultMaxTokens + } + + if cfg.Mirostat == nil { + cfg.Mirostat = &defaultMirostat + } + + if cfg.MirostatETA == nil { + cfg.MirostatETA = &defaultMirostatETA + } + + if cfg.MirostatTAU == nil { + cfg.MirostatTAU = &defaultMirostatTAU + } + if cfg.NGPULayers == nil { + cfg.NGPULayers = &defaultNGPULayers + } + + if cfg.LowVRAM == nil { + cfg.LowVRAM = &falseV + } + + // Value passed by the top level are treated as default (no implicit defaults) + // defaults are set by the user + if ctx == 0 { + ctx = 1024 + } + + if cfg.ContextSize == nil { + cfg.ContextSize = &ctx + } + + if threads == 0 { + // Threads can't be 0 + threads = 4 + } + + if cfg.Threads == nil { + cfg.Threads = &threads + } + + if cfg.F16 == nil { + cfg.F16 = &f16 + } + + if debug { + cfg.Debug = &debug } } @@ -208,23 +286,63 @@ type BackendConfigLoader struct { sync.Mutex } +type LoadOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*LoadOptions) + +func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + // Load a config file for a model -func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigLoader, debug bool, threads, ctx int, f16 bool) (*BackendConfig, error) { +func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + lo := &LoadOptions{} + lo.Apply(opts...) + // Load a config file if present after the model name - modelConfig := filepath.Join(modelPath, modelName+".yaml") - - var cfg *BackendConfig - - defaults := func() { - cfg = DefaultConfig(modelName) - cfg.ContextSize = ctx - cfg.Threads = threads - cfg.F16 = f16 - cfg.Debug = debug + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, } cfgExisting, exists := cl.GetBackendConfig(modelName) - if !exists { + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") if _, err := os.Stat(modelConfig); err == nil { if err := cl.LoadBackendConfig(modelConfig); err != nil { return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) @@ -232,32 +350,11 @@ func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigL cfgExisting, exists = cl.GetBackendConfig(modelName) if exists { cfg = &cfgExisting - } else { - defaults() } - } else { - defaults() - } - } else { - cfg = &cfgExisting - } - - // Set the parameters for the language model prediction - //updateConfig(cfg, input) - - // Don't allow 0 as setting - if cfg.Threads == 0 { - if threads != 0 { - cfg.Threads = threads - } else { - cfg.Threads = 4 } } - // Enforce debug flag if passed from CLI - if debug { - cfg.Debug = true - } + cfg.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) return cfg, nil } @@ -267,7 +364,10 @@ func NewBackendConfigLoader() *BackendConfigLoader { configs: make(map[string]BackendConfig), } } -func ReadBackendConfigFile(file string) ([]*BackendConfig, error) { +func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + c := &[]*BackendConfig{} f, err := os.ReadFile(file) if err != nil { @@ -277,10 +377,17 @@ func ReadBackendConfigFile(file string) ([]*BackendConfig, error) { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } + for _, cc := range *c { + cc.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) + } + return *c, nil } -func ReadBackendConfig(file string) (*BackendConfig, error) { +func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + c := &BackendConfig{} f, err := os.ReadFile(file) if err != nil { @@ -290,13 +397,14 @@ func ReadBackendConfig(file string) (*BackendConfig, error) { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } + c.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) return c, nil } -func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error { +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { cm.Lock() defer cm.Unlock() - c, err := ReadBackendConfigFile(file) + c, err := ReadBackendConfigFile(file, opts...) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } @@ -307,10 +415,10 @@ func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error { return nil } -func (cl *BackendConfigLoader) LoadBackendConfig(file string) error { +func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { cl.Lock() defer cl.Unlock() - c, err := ReadBackendConfig(file) + c, err := ReadBackendConfig(file, opts...) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } @@ -407,7 +515,9 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { return nil } -func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error { +// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// (non-recursive) +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { cm.Lock() defer cm.Unlock() entries, err := os.ReadDir(path) @@ -427,7 +537,7 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error { if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { continue } - c, err := ReadBackendConfig(filepath.Join(path, file.Name())) + c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) if err == nil { cm.configs[c.Name] = *c } diff --git a/core/http/api_test.go b/core/http/api_test.go index 8f3cfc91..b0579a19 100644 --- a/core/http/api_test.go +++ b/core/http/api_test.go @@ -386,7 +386,7 @@ var _ = Describe("API test", func() { var res map[string]string err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) - Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res)) + Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 84fb7a55..9c3f890d 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -26,7 +26,14 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false) + + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 8f535801..d59b1051 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -196,7 +196,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon baseURL := c.BaseURL() - fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) + fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) if err != nil { return err } diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index 46ff2438..505244c4 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -74,10 +74,10 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque if input.Echo { config.Echo = input.Echo } - if input.TopK != 0 { + if input.TopK != nil { config.TopK = input.TopK } - if input.TopP != 0 { + if input.TopP != nil { config.TopP = input.TopP } @@ -117,11 +117,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque config.Grammar = input.Grammar } - if input.Temperature != 0 { + if input.Temperature != nil { config.Temperature = input.Temperature } - if input.Maxtokens != 0 { + if input.Maxtokens != nil { config.Maxtokens = input.Maxtokens } @@ -193,30 +193,14 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque config.Batch = input.Batch } - if input.F16 { - config.F16 = input.F16 - } - if input.IgnoreEOS { config.IgnoreEOS = input.IgnoreEOS } - if input.Seed != 0 { + if input.Seed != nil { config.Seed = input.Seed } - if input.Mirostat != 0 { - config.LLMConfig.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.LLMConfig.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.LLMConfig.MirostatTAU = input.MirostatTAU - } - if input.TypicalP != 0 { config.TypicalP = input.TypicalP } @@ -272,7 +256,12 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque } func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16) + cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, + config.LoadOptionDebug(debug), + config.LoadOptionThreads(threads), + config.LoadOptionContextSize(ctx), + config.LoadOptionF16(f16), + ) // Set the parameters for the language model prediction updateRequestConfig(cfg, input) diff --git a/core/schema/prediction.go b/core/schema/prediction.go index efd085a4..d75e5eb8 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -12,28 +12,23 @@ type PredictionOptions struct { N int `json:"n"` // Common options between all the API calls, part of the OpenAI spec - TopP float64 `json:"top_p" yaml:"top_p"` - TopK int `json:"top_k" yaml:"top_k"` - Temperature float64 `json:"temperature" yaml:"temperature"` - Maxtokens int `json:"max_tokens" yaml:"max_tokens"` - Echo bool `json:"echo"` + TopP *float64 `json:"top_p" yaml:"top_p"` + TopK *int `json:"top_k" yaml:"top_k"` + Temperature *float64 `json:"temperature" yaml:"temperature"` + Maxtokens *int `json:"max_tokens" yaml:"max_tokens"` + Echo bool `json:"echo"` // Custom parameters - not present in the OpenAI API Batch int `json:"batch" yaml:"batch"` - F16 bool `json:"f16" yaml:"f16"` IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` Keep int `json:"n_keep" yaml:"n_keep"` - MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` - MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` - Mirostat int `json:"mirostat" yaml:"mirostat"` - FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` TFZ float64 `json:"tfz" yaml:"tfz"` TypicalP float64 `json:"typical_p" yaml:"typical_p"` - Seed int `json:"seed" yaml:"seed"` + Seed *int `json:"seed" yaml:"seed"` NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"` RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"` diff --git a/main.go b/main.go index 237191cf..21560e5a 100644 --- a/main.go +++ b/main.go @@ -497,7 +497,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit return errors.New("model not found") } - c.Threads = threads + c.Threads = &threads defer ml.StopAllGRPC()