feat: add experimental support for embeddings as arrays (#207)

This commit is contained in:
Ettore Di Giacinto 2023-05-08 19:31:18 +02:00 committed by GitHub
parent bc03c492a0
commit 89dfa0f5fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 7 deletions

View File

@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
GOLLAMA_VERSION?=cf9b522db63898dcc5eb86e37c979ab85cbd583e
GOLLAMA_VERSION?=b4e97a42d0c10ada6b529b0ec17b05c72435aeab
GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp

View File

@ -33,6 +33,7 @@ type Config struct {
Mirostat int `yaml:"mirostat"`
PromptStrings, InputStrings []string
InputToken [][]int
}
type TemplateConfig struct {
@ -186,8 +187,15 @@ func updateConfig(config *Config, input *OpenAIRequest) {
}
case []interface{}:
for _, pp := range inputs {
if s, ok := pp.(string); ok {
config.InputStrings = append(config.InputStrings, s)
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}

View File

@ -177,10 +177,23 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Parameter Config: %+v", config)
items := []Item{}
for i, s := range config.InputStrings {
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, loader, *config)
embedFn, err := ModelEmbedding("", s, loader, *config)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
if err != nil {
return err
}

View File

@ -32,7 +32,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
return llamaOpts
}
func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
@ -57,6 +57,9 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]fl
case *llama.LLama:
fn = func() ([]float32, error) {
predictOptions := buildLLamaPredictOptions(c)
if len(tokens) > 0 {
return model.TokenEmbeddings(tokens, predictOptions...)
}
return model.Embeddings(s, predictOptions...)
}
default: