2024-02-21 01:21:19 +00:00
package config
2023-07-14 23:19:43 +00:00
import (
2024-04-17 21:33:49 +00:00
"os"
2024-05-21 12:33:47 +00:00
"regexp"
2024-10-01 18:55:46 +00:00
"slices"
2024-05-21 12:33:47 +00:00
"strings"
2024-04-17 21:33:49 +00:00
2024-06-23 08:24:36 +00:00
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
2024-10-01 18:55:46 +00:00
"gopkg.in/yaml.v3"
2023-07-14 23:19:43 +00:00
)
2024-04-03 20:25:47 +00:00
// RAND_SEED is the seed SetDefaults assigns when a config leaves the seed
// unset. NOTE(review): presumably backends treat -1 as "pick a random seed";
// confirm against the backend implementations.
const RAND_SEED = -1
2024-06-01 18:26:27 +00:00
// TTSConfig holds the text-to-speech settings of a model config. It is
// embedded in BackendConfig under the `tts` YAML key.
type TTSConfig struct {
	// Voice wav path or id
	Voice string `yaml:"voice"`

	// AudioPath is read from the `audio_path` key.
	// NOTE(review): its exact meaning is defined by the consuming TTS
	// backend (not visible here) — confirm.
	AudioPath string `yaml:"audio_path"`
}
2024-03-01 15:19:53 +00:00
// BackendConfig is the configuration of a single model, decoded from a YAML
// model definition (see UnmarshalYAML). It combines:
//   - prediction/sampling parameters (embedded schema.PredictionOptions,
//     nested under the `parameters` key);
//   - generic LLM options (LLMConfig, inlined at the top level);
//   - backend-specific sections (autogptq, diffusers, grpc, tts, ...);
//   - fields tagged `yaml:"-"`, which are never read from or written to YAML.
//
// Pointer fields are left nil when not configured so that SetDefaults can
// tell "unset" apart from an explicit zero value.
type BackendConfig struct {
	schema.PredictionOptions `yaml:"parameters"`
	Name                     string `yaml:"name"`

	F16        *bool             `yaml:"f16"`
	Threads    *int              `yaml:"threads"`
	Debug      *bool             `yaml:"debug"`
	Roles      map[string]string `yaml:"roles"`
	Embeddings *bool             `yaml:"embeddings"`
	// Backend names the inference backend; Validate restricts it to
	// letters, digits, '-' and '_'.
	Backend        string         `yaml:"backend"`
	TemplateConfig TemplateConfig `yaml:"template"`
	// KnownUsecaseStrings are the use cases declared in YAML (e.g. "chat");
	// UnmarshalYAML turns them into the KnownUsecases bitmask.
	KnownUsecaseStrings []string               `yaml:"known_usecases"`
	KnownUsecases       *BackendConfigUsecases `yaml:"-"`

	// Not serialized. NOTE(review): presumably filled in per request by the
	// API layer — confirm against callers.
	PromptStrings, InputStrings                []string               `yaml:"-"`
	InputToken                                 [][]int                `yaml:"-"`
	functionCallString, functionCallNameString string                 `yaml:"-"`
	ResponseFormat                             string                 `yaml:"-"`
	ResponseFormatMap                          map[string]interface{} `yaml:"-"`

	FunctionsConfig functions.FunctionsConfig `yaml:"function"`

	FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline"`
	// AutoGPTQ specifics
	AutoGPTQ AutoGPTQ `yaml:"autogptq"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers"`
	Step      int       `yaml:"step"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc"`

	// TTS specifics
	TTSConfig `yaml:"tts"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda"`

	// DownloadFiles lists extra files associated with the model; their
	// filenames are checked by Validate.
	DownloadFiles []File `yaml:"download_files"`

	Description string `yaml:"description"`
	Usage       string `yaml:"usage"`

	// Options is a free-form list of backend options.
	// NOTE(review): interpretation is up to the backend — confirm.
	Options []string `yaml:"options"`
}
// File describes an additional file listed in a model config's
// `download_files` section (BackendConfig.DownloadFiles).
type File struct {
	// Filename is the local name of the file; Validate rejects absolute
	// paths and names containing "..".
	Filename string `yaml:"filename" json:"filename"`
	// SHA256 is presumably the expected hex digest used to verify the
	// download — NOTE(review): confirm in the downloader.
	SHA256 string `yaml:"sha256" json:"sha256"`
	// URI is the location the file is fetched from.
	URI downloader.URI `yaml:"uri" json:"uri"`
}
2023-08-19 14:15:22 +00:00
// FeatureFlag is a registry of named boolean switches. A flag that is absent
// or explicitly null is treated as disabled.
type FeatureFlag map[string]*bool

// Enabled reports whether the flag named s is present and set to true.
// It is safe to call on a nil FeatureFlag.
func (ff FeatureFlag) Enabled(s string) bool {
	// A missing key and an explicit nil both yield a nil pointer here.
	if flag := ff[s]; flag != nil {
		return *flag
	}
	return false
}
2023-08-15 23:11:32 +00:00
// GRPC holds options for talking to a gRPC backend process (the `grpc` key
// of a model config).
type GRPC struct {
	// NOTE(review): presumably the number of connection attempts and the
	// pause between them (units not visible here) — confirm with the loader.
	Attempts          int `yaml:"attempts"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}
// Diffusers holds settings specific to the diffusers image backend (the
// `diffusers` key of a model config).
type Diffusers struct {
	CUDA bool `yaml:"cuda"`
	// PipelineType must be non-empty for GuessUsecases to consider a
	// diffusers model image-capable.
	PipelineType     string `yaml:"pipeline_type"`
	SchedulerType    string `yaml:"scheduler_type"`
	EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
	IMG2IMG          bool   `yaml:"img2img"`           // Image to Image Diffuser
	// NOTE(review): the original comment said "Skip every N frames"; in
	// diffusion pipelines clip_skip usually means the number of final CLIP
	// text-encoder layers to skip — confirm against the backend.
	ClipSkip      int    `yaml:"clip_skip"`
	ClipModel     string `yaml:"clip_model"`     // Clip model to use
	ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
	ControlNet    string `yaml:"control_net"`
}
2024-05-14 23:17:02 +00:00
// LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends. It is inlined into BackendConfig,
// so its keys live at the top level of a model's YAML.
type LLMConfig struct {
	SystemPrompt    string  `yaml:"system_prompt"`
	TensorSplit     string  `yaml:"tensor_split"`
	MainGPU         string  `yaml:"main_gpu"`
	RMSNormEps      float32 `yaml:"rms_norm_eps"`
	NGQA            int32   `yaml:"ngqa"`
	PromptCachePath string  `yaml:"prompt_cache_path"`
	PromptCacheAll  bool    `yaml:"prompt_cache_all"`
	PromptCacheRO   bool    `yaml:"prompt_cache_ro"`

	// Pointer fields stay nil when unset; BackendConfig.SetDefaults fills
	// each of these with a default in that case.
	MirostatETA *float64 `yaml:"mirostat_eta"`
	MirostatTAU *float64 `yaml:"mirostat_tau"`
	Mirostat    *int     `yaml:"mirostat"`
	NGPULayers  *int     `yaml:"gpu_layers"`
	MMap        *bool    `yaml:"mmap"`
	MMlock      *bool    `yaml:"mmlock"`
	LowVRAM     *bool    `yaml:"low_vram"`

	// Grammar and output-shaping options.
	// NOTE(review): applied by the inference layer, not visible here.
	Grammar      string   `yaml:"grammar"`
	StopWords    []string `yaml:"stopwords"`
	Cutstrings   []string `yaml:"cutstrings"`
	ExtractRegex []string `yaml:"extract_regex"`
	TrimSpace    []string `yaml:"trimspace"`
	TrimSuffix   []string `yaml:"trimsuffix"`

	// ContextSize is defaulted by SetDefaults when nil.
	ContextSize          *int      `yaml:"context_size"`
	NUMA                 bool      `yaml:"numa"`
	LoraAdapter          string    `yaml:"lora_adapter"`
	LoraBase             string    `yaml:"lora_base"`
	LoraAdapters         []string  `yaml:"lora_adapters"`
	LoraScales           []float32 `yaml:"lora_scales"`
	LoraScale            float32   `yaml:"lora_scale"`
	NoMulMatQ            bool      `yaml:"no_mulmatq"`
	DraftModel           string    `yaml:"draft_model"`
	NDraft               int32     `yaml:"n_draft"`
	Quantization         string    `yaml:"quantization"`
	LoadFormat           string    `yaml:"load_format"`
	GPUMemoryUtilization float32   `yaml:"gpu_memory_utilization"` // vLLM
	TrustRemoteCode      bool      `yaml:"trust_remote_code"`      // vLLM
	EnforceEager         bool      `yaml:"enforce_eager"`          // vLLM
	SwapSpace            int       `yaml:"swap_space"`             // vLLM
	MaxModelLen          int       `yaml:"max_model_len"`          // vLLM
	TensorParallelSize   int       `yaml:"tensor_parallel_size"`   // vLLM
	// MMProj may be a local path or a URL (see MMProjFileName/IsMMProjURL).
	MMProj string `yaml:"mmproj"`

	FlashAttention bool   `yaml:"flash_attention"`
	NoKVOffloading bool   `yaml:"no_kv_offloading"`
	CacheTypeK     string `yaml:"cache_type_k"`
	CacheTypeV     string `yaml:"cache_type_v"`

	RopeScaling string `yaml:"rope_scaling"`
	ModelType   string `yaml:"type"`

	YarnExtFactor  float32 `yaml:"yarn_ext_factor"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
	YarnBetaFast   float32 `yaml:"yarn_beta_fast"`
	YarnBetaSlow   float32 `yaml:"yarn_beta_slow"`

	CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
}
2023-08-07 20:39:10 +00:00
2024-05-14 23:17:02 +00:00
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend.
// It is read from the `autogptq` key of a model config.
type AutoGPTQ struct {
	ModelBaseName    string `yaml:"model_base_name"`
	Device           string `yaml:"device"`
	Triton           bool   `yaml:"triton"`
	UseFastTokenizer bool   `yaml:"use_fast_tokenizer"`
}
2024-05-14 23:17:02 +00:00
// TemplateConfig is a struct that holds the configuration of the templating system.
// It is read from the `template` key of a model config.
type TemplateConfig struct {
	// Chat is the template used in the chat completion endpoint
	Chat string `yaml:"chat"`

	// ChatMessage is the template used for chat messages
	ChatMessage string `yaml:"chat_message"`

	// Completion is the template used for completion requests
	Completion string `yaml:"completion"`

	// Edit is the template used for edit completion requests
	Edit string `yaml:"edit"`

	// Functions is the template used when tools are present in the client requests
	Functions string `yaml:"function"`

	// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
	// Note: this is mostly consumed for backends such as vllm and transformers
	// that can use the tokenizers specified in the JSON config files of the models
	UseTokenizerTemplate bool `yaml:"use_tokenizer_template"`

	// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
	// It defaults to \n (nil means "not set"; the fallback is applied by the consumer).
	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`

	// Multimodal is read from the `multimodal` key.
	// NOTE(review): presumably a template for multimodal content — confirm.
	Multimodal string `yaml:"multimodal"`

	// JinjaTemplate is read from the `jinja_template` key.
	// NOTE(review): presumably marks the templates as Jinja rather than Go
	// templates — confirm with the template evaluator.
	JinjaTemplate bool `yaml:"jinja_template"`
}
2024-10-01 18:55:46 +00:00
// UnmarshalYAML implements yaml.Unmarshaler. It decodes the node into a fresh
// config and then derives KnownUsecases from the declared KnownUsecaseStrings.
// On a decode error c is left untouched.
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
	// rawConfig has BackendConfig's fields but none of its methods, so
	// decoding into it does not recurse back into this method.
	type rawConfig BackendConfig

	decoded := rawConfig{}
	if err := value.Decode(&decoded); err != nil {
		return err
	}

	*c = BackendConfig(decoded)
	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	return nil
}
2024-03-01 15:19:53 +00:00
func ( c * BackendConfig ) SetFunctionCallString ( s string ) {
2023-07-14 23:19:43 +00:00
c . functionCallString = s
}
2024-03-01 15:19:53 +00:00
func ( c * BackendConfig ) SetFunctionCallNameString ( s string ) {
2023-07-14 23:19:43 +00:00
c . functionCallNameString = s
}
2024-03-01 15:19:53 +00:00
func ( c * BackendConfig ) ShouldUseFunctions ( ) bool {
2023-07-14 23:19:43 +00:00
return ( ( c . functionCallString != "none" || c . functionCallString == "" ) || c . ShouldCallSpecificFunction ( ) )
}
2024-03-01 15:19:53 +00:00
func ( c * BackendConfig ) ShouldCallSpecificFunction ( ) bool {
2023-07-14 23:19:43 +00:00
return len ( c . functionCallNameString ) > 0
}
2024-04-28 21:42:46 +00:00
// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func ( c * BackendConfig ) MMProjFileName ( ) string {
2024-08-02 18:06:25 +00:00
uri := downloader . URI ( c . MMProj )
if uri . LooksLikeURL ( ) {
f , _ := uri . FilenameFromUrl ( )
return f
2024-04-28 21:42:46 +00:00
}
return c . MMProj
}
func ( c * BackendConfig ) IsMMProjURL ( ) bool {
2024-08-02 18:06:25 +00:00
uri := downloader . URI ( c . MMProj )
return uri . LooksLikeURL ( )
2024-04-28 21:42:46 +00:00
}
func ( c * BackendConfig ) IsModelURL ( ) bool {
2024-08-02 18:06:25 +00:00
uri := downloader . URI ( c . Model )
return uri . LooksLikeURL ( )
2024-04-28 21:42:46 +00:00
}
// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func ( c * BackendConfig ) ModelFileName ( ) string {
2024-08-02 18:06:25 +00:00
uri := downloader . URI ( c . Model )
if uri . LooksLikeURL ( ) {
f , _ := uri . FilenameFromUrl ( )
return f
2024-04-28 21:42:46 +00:00
}
return c . Model
}
2024-03-01 15:19:53 +00:00
func ( c * BackendConfig ) FunctionToCall ( ) string {
2024-04-01 17:39:54 +00:00
if c . functionCallNameString != "" &&
c . functionCallNameString != "none" && c . functionCallNameString != "auto" {
return c . functionCallNameString
}
return c . functionCallString
2023-07-14 23:19:43 +00:00
}
2024-03-22 19:55:11 +00:00
func ( cfg * BackendConfig ) SetDefaults ( opts ... ConfigLoaderOption ) {
2024-04-17 21:33:49 +00:00
lo := & LoadOptions { }
2024-03-22 19:55:11 +00:00
lo . Apply ( opts ... )
ctx := lo . ctxSize
threads := lo . threads
f16 := lo . f16
debug := lo . debug
2024-04-06 20:56:45 +00:00
// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
defaultTopP := 0.95
defaultTopK := 40
2024-03-13 09:05:30 +00:00
defaultTemp := 0.9
defaultMirostat := 2
defaultMirostatTAU := 5.0
defaultMirostatETA := 0.1
2024-04-06 20:56:45 +00:00
defaultTypicalP := 1.0
defaultTFZ := 1.0
2024-04-21 14:34:00 +00:00
defaultZero := 0
2024-03-13 09:05:30 +00:00
// Try to offload all GPU layers (if GPU is found)
2024-04-20 18:20:10 +00:00
defaultHigh := 99999999
2024-03-13 09:05:30 +00:00
trueV := true
falseV := false
if cfg . Seed == nil {
// random number generator seed
2024-04-03 20:25:47 +00:00
defaultSeed := RAND_SEED
2024-03-13 09:05:30 +00:00
cfg . Seed = & defaultSeed
}
if cfg . TopK == nil {
cfg . TopK = & defaultTopK
}
2024-04-06 20:56:45 +00:00
if cfg . TypicalP == nil {
cfg . TypicalP = & defaultTypicalP
}
if cfg . TFZ == nil {
cfg . TFZ = & defaultTFZ
}
2024-03-13 09:05:30 +00:00
if cfg . MMap == nil {
// MMap is enabled by default
2024-04-27 07:08:33 +00:00
// Only exception is for Intel GPUs
if os . Getenv ( "XPU" ) != "" {
cfg . MMap = & falseV
} else {
cfg . MMap = & trueV
}
2024-03-01 15:19:53 +00:00
}
2024-03-13 09:05:30 +00:00
if cfg . MMlock == nil {
// MMlock is disabled by default
cfg . MMlock = & falseV
}
if cfg . TopP == nil {
cfg . TopP = & defaultTopP
}
if cfg . Temperature == nil {
cfg . Temperature = & defaultTemp
}
if cfg . Maxtokens == nil {
2024-04-21 14:34:00 +00:00
cfg . Maxtokens = & defaultZero
2024-03-13 09:05:30 +00:00
}
if cfg . Mirostat == nil {
cfg . Mirostat = & defaultMirostat
}
if cfg . MirostatETA == nil {
cfg . MirostatETA = & defaultMirostatETA
}
if cfg . MirostatTAU == nil {
cfg . MirostatTAU = & defaultMirostatTAU
}
if cfg . NGPULayers == nil {
2024-04-20 18:20:10 +00:00
cfg . NGPULayers = & defaultHigh
2024-03-13 09:05:30 +00:00
}
if cfg . LowVRAM == nil {
cfg . LowVRAM = & falseV
}
2024-07-15 20:54:16 +00:00
if cfg . Embeddings == nil {
cfg . Embeddings = & falseV
}
2024-03-13 09:05:30 +00:00
// Value passed by the top level are treated as default (no implicit defaults)
// defaults are set by the user
if ctx == 0 {
ctx = 1024
}
if cfg . ContextSize == nil {
cfg . ContextSize = & ctx
}
if threads == 0 {
// Threads can't be 0
threads = 4
}
if cfg . Threads == nil {
cfg . Threads = & threads
}
if cfg . F16 == nil {
cfg . F16 = & f16
}
2024-03-18 17:59:39 +00:00
if cfg . Debug == nil {
cfg . Debug = & falseV
}
2024-03-13 09:05:30 +00:00
if debug {
2024-03-18 17:59:39 +00:00
cfg . Debug = & trueV
2024-03-01 15:19:53 +00:00
}
2024-06-08 20:13:02 +00:00
guessDefaultsFromFile ( cfg , lo . modelPath )
2024-03-01 15:19:53 +00:00
}
2024-05-21 12:33:47 +00:00
func ( c * BackendConfig ) Validate ( ) bool {
2024-05-23 20:48:12 +00:00
downloadedFileNames := [ ] string { }
for _ , f := range c . DownloadFiles {
downloadedFileNames = append ( downloadedFileNames , f . Filename )
}
validationTargets := [ ] string { c . Backend , c . Model , c . MMProj }
validationTargets = append ( validationTargets , downloadedFileNames ... )
2024-05-21 12:33:47 +00:00
// Simple validation to make sure the model can be correctly loaded
2024-05-23 20:48:12 +00:00
for _ , n := range validationTargets {
2024-05-22 06:32:30 +00:00
if n == "" {
continue
}
2024-05-21 12:33:47 +00:00
if strings . HasPrefix ( n , string ( os . PathSeparator ) ) ||
strings . Contains ( n , ".." ) {
return false
}
}
if c . Backend != "" {
// a regex that checks that is a string name with no special characters, except '-' and '_'
re := regexp . MustCompile ( ` ^[a-zA-Z0-9-_]+$ ` )
return re . MatchString ( c . Backend )
}
return true
}
2024-06-08 20:13:02 +00:00
// HasTemplate reports whether at least one of the completion, edit, chat or
// chat-message templates is configured. The functions template, tokenizer
// template and Jinja settings are not considered.
func (c *BackendConfig) HasTemplate() bool {
	tc := &c.TemplateConfig
	for _, tmpl := range []string{tc.Completion, tc.Edit, tc.Chat, tc.ChatMessage} {
		if tmpl != "" {
			return true
		}
	}
	return false
}
2024-10-01 18:55:46 +00:00
// BackendConfigUsecases is a bitmask of the endpoint kinds ("use cases") a
// model config may serve. Each single flag occupies its own bit, so flags are
// combined with | and tested with &. FLAG_ANY (0) is the empty set.
type BackendConfigUsecases int

const (
	FLAG_ANY              BackendConfigUsecases = 0b000000000
	FLAG_CHAT             BackendConfigUsecases = 0b000000001
	FLAG_COMPLETION       BackendConfigUsecases = 0b000000010
	FLAG_EDIT             BackendConfigUsecases = 0b000000100
	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b000001000
	FLAG_RERANK           BackendConfigUsecases = 0b000010000
	FLAG_IMAGE            BackendConfigUsecases = 0b000100000
	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b001000000
	FLAG_TTS              BackendConfigUsecases = 0b010000000
	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000

	// Common Subsets

	// FLAG_LLM is the union of the chat, completion and edit use cases. It
	// must be built with bitwise OR: AND-ing distinct single-bit flags yields
	// 0 (== FLAG_ANY), which would make FLAG_LLM match every config and make
	// `known_usecases: [llm]` set no bits at all.
	FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)
// GetAllBackendConfigUsecases returns a new map from each flag constant's
// name (e.g. "FLAG_CHAT") to its value. The caller owns the returned map.
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
	byName := make(map[string]BackendConfigUsecases, 11)
	byName["FLAG_ANY"] = FLAG_ANY
	byName["FLAG_CHAT"] = FLAG_CHAT
	byName["FLAG_COMPLETION"] = FLAG_COMPLETION
	byName["FLAG_EDIT"] = FLAG_EDIT
	byName["FLAG_EMBEDDINGS"] = FLAG_EMBEDDINGS
	byName["FLAG_RERANK"] = FLAG_RERANK
	byName["FLAG_IMAGE"] = FLAG_IMAGE
	byName["FLAG_TRANSCRIPT"] = FLAG_TRANSCRIPT
	byName["FLAG_TTS"] = FLAG_TTS
	byName["FLAG_SOUND_GENERATION"] = FLAG_SOUND_GENERATION
	byName["FLAG_LLM"] = FLAG_LLM
	return byName
}
// GetUsecasesFromYAML converts use-case names as written in YAML (matched
// case-insensitively, e.g. "chat" -> FLAG_CHAT) into a combined bitmask.
// Unknown names are ignored. It returns nil when input is empty, so callers
// can tell "not declared" apart from "declared but matched nothing" (a
// pointer to FLAG_ANY).
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
	if len(input) == 0 {
		return nil
	}

	known := GetAllBackendConfigUsecases()
	combined := FLAG_ANY
	for _, name := range input {
		if flag, ok := known["FLAG_"+strings.ToUpper(name)]; ok {
			combined |= flag
		}
	}
	return &combined
}
// HasUsecases reports whether the config can plausibly serve every use case
// in u. A declared KnownUsecases mask that covers all of u is trusted
// outright. Otherwise the answer falls back to the GuessUsecases heuristic.
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
	if known := c.KnownUsecases; known != nil && u&*known == u {
		return true
	}
	return c.GuessUsecases(u)
}
// GuessUsecases is a **heuristic**: it reports whether the config looks able
// to serve every use case in u without loading the backend. Text use cases
// are inferred from configured templates and embeddings from the Embeddings
// flag. Image, rerank, transcription, TTS and sound generation still rely on
// hard-coded backend names, so this list must be updated when new backends
// appear. Checking config properties instead would remove that burden, but
// that is not possible for every service yet.
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
	wants := func(flag BackendConfigUsecases) bool { return u&flag == flag }
	tc := &c.TemplateConfig

	switch {
	case wants(FLAG_CHAT) && tc.Chat == "" && tc.ChatMessage == "":
		return false
	case wants(FLAG_COMPLETION) && tc.Completion == "":
		return false
	case wants(FLAG_EDIT) && tc.Edit == "":
		return false
	case wants(FLAG_EMBEDDINGS) && (c.Embeddings == nil || !*c.Embeddings):
		return false
	case wants(FLAG_IMAGE) && !slices.Contains([]string{"diffusers", "tinydream", "stablediffusion"}, c.Backend):
		return false
	case wants(FLAG_IMAGE) && c.Backend == "diffusers" && c.Diffusers.PipelineType == "":
		// diffusers can only produce images once a pipeline is configured.
		return false
	case wants(FLAG_RERANK) && c.Backend != "rerankers":
		return false
	case wants(FLAG_TRANSCRIPT) && c.Backend != "whisper":
		return false
	case wants(FLAG_TTS) && !slices.Contains([]string{"piper", "transformers-musicgen", "parler-tts"}, c.Backend):
		return false
	case wants(FLAG_SOUND_GENERATION) && c.Backend != "transformers-musicgen":
		return false
	}
	return true
}