feat: allow to set a prompt cache path and enable saving state (#395)

Signed-off-by: mudler <mudler@mocaccino.org>
Author: Ettore Di Giacinto, 2023-05-27 14:29:11 +02:00 (committed by GitHub)
parent 76c881043e
commit 217dbb448e
3 changed files with 39 additions and 22 deletions

--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=8bd97d532e90cf34e755b3ea2d8aa17000443cf2
+GOLLAMA_VERSION?=fbec625895ba0c458f783b62c8569135c5e80d79
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
 GPT4ALL_VERSION?=73db20ba85fbbdc66a56e2619394c0eea40dc72b
 GOGGMLTRANSFORMERS_VERSION?=c4c581f1853cf1b66276501c7c0dbea1e3e564b7
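(This go-llama.cpp bump is presumably what brings the prompt-cache options used below, llama.SetPathPromptCache and llama.EnablePromptCacheAll, into the binding.)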

--- a/api/config.go
+++ b/api/config.go
@@ -16,24 +16,28 @@ import (
 )
 
 type Config struct {
 	OpenAIRequest `yaml:"parameters"`
 	Name          string            `yaml:"name"`
 	StopWords     []string          `yaml:"stopwords"`
 	Cutstrings    []string          `yaml:"cutstrings"`
 	TrimSpace     []string          `yaml:"trimspace"`
 	ContextSize   int               `yaml:"context_size"`
 	F16           bool              `yaml:"f16"`
 	Threads       int               `yaml:"threads"`
 	Debug         bool              `yaml:"debug"`
 	Roles         map[string]string `yaml:"roles"`
 	Embeddings    bool              `yaml:"embeddings"`
 	Backend       string            `yaml:"backend"`
 	TemplateConfig TemplateConfig   `yaml:"template"`
 	MirostatETA   float64           `yaml:"mirostat_eta"`
 	MirostatTAU   float64           `yaml:"mirostat_tau"`
 	Mirostat      int               `yaml:"mirostat"`
 	NGPULayers    int               `yaml:"gpu_layers"`
 	ImageGenerationAssets string    `yaml:"asset_dir"`
+
+	PromptCachePath string `yaml:"prompt_cache_path"`
+	PromptCacheAll  bool   `yaml:"prompt_cache_all"`
+
 	PromptStrings, InputStrings []string
 	InputToken                  [][]int
 }
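The two new fields are read from a model's YAML definition like every other Config option. A minimal sketch of a model config that exercises them; the name, model file, and values are illustrative, and only the two prompt_cache_* keys are confirmed by the yaml tags above:

name: gpt-3.5-turbo
context_size: 1024
threads: 4
parameters:
  model: ggml-model.bin
# resolved relative to the models directory (see prediction.go below)
prompt_cache_path: "cache/prompt.bin"
# analogous to llama.cpp's --prompt-cache-all: generations are saved to the cache as well
prompt_cache_all: true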

--- a/api/prediction.go
+++ b/api/prediction.go
@@ -2,6 +2,8 @@ package api
 
 import (
 	"fmt"
+	"os"
+	"path/filepath"
 	"regexp"
 	"strings"
 	"sync"
@@ -102,7 +104,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	switch model := inferenceModel.(type) {
 	case *llama.LLama:
 		fn = func() ([]float32, error) {
-			predictOptions := buildLLamaPredictOptions(c)
+			predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
 			if len(tokens) > 0 {
 				return model.TokenEmbeddings(tokens, predictOptions...)
 			}
@@ -151,7 +153,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	}, nil
 }
 
-func buildLLamaPredictOptions(c Config) []llama.PredictOption {
+func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption {
 	// Generate the prediction using the language model
 	predictOptions := []llama.PredictOption{
 		llama.SetTemperature(c.Temperature),
@@ -161,6 +163,17 @@ func buildLLamaPredictOptions(c Config) []llama.PredictOption {
 		llama.SetThreads(c.Threads),
 	}
 
+	if c.PromptCacheAll {
+		predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
+	}
+
+	if c.PromptCachePath != "" {
+		// Create parent directory
+		p := filepath.Join(modelPath, c.PromptCachePath)
+		os.MkdirAll(filepath.Dir(p), 0755)
+		predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
+	}
+
 	if c.Mirostat != 0 {
 		predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
 	}
@@ -469,7 +482,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 		model.SetTokenCallback(tokenCallback)
 	}
 
-	predictOptions := buildLLamaPredictOptions(c)
+	predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
 
 	str, er := model.Predict(
 		s,
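Taken together, the new wiring can be exercised directly against the binding. A standalone sketch, assuming the github.com/go-skynet/go-llama.cpp import path that GOLLAMA_VERSION pins; the paths and thread count are made up, while the llama.* option names are exactly the ones the diff uses:

package main

import (
	"fmt"
	"os"
	"path/filepath"

	llama "github.com/go-skynet/go-llama.cpp"
)

func main() {
	modelPath := "/models"         // stands in for loader.ModelPath
	cacheRel := "cache/prompt.bin" // stands in for c.PromptCachePath

	// Resolve the cache file under the models directory and create its
	// parent directory, mirroring buildLLamaPredictOptions above.
	p := filepath.Join(modelPath, cacheRel)
	if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
		fmt.Println("mkdir:", err)
		return
	}

	model, err := llama.New(filepath.Join(modelPath, "ggml-model.bin"))
	if err != nil {
		fmt.Println("load:", err)
		return
	}

	// The first call pays the full prompt-evaluation cost and persists the
	// llama.cpp state to p; later calls sharing a prefix can restore it.
	out, err := model.Predict("Hello",
		llama.SetThreads(4),
		llama.SetPathPromptCache(p),
		llama.EnablePromptCacheAll,
	)
	if err != nil {
		fmt.Println("predict:", err)
		return
	}
	fmt.Println(out)
}

Note that the committed code discards the os.MkdirAll error: if the directory cannot be created, the cache save fails later inside llama.cpp rather than surfacing at configuration time.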