LocalAI/core/backend/embeddings.go

package backend

import (
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/schema"
	"github.com/go-skynet/LocalAI/pkg/concurrency"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/google/uuid"
)
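
// EmbeddingsBackendService bundles the model loader, backend configuration
// loader and application configuration needed to serve OpenAI-compatible
// embedding requests.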
type EmbeddingsBackendService struct {
	ml        *model.ModelLoader
	bcl       *config.BackendConfigLoader
	appConfig *config.ApplicationConfig
}
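
// NewEmbeddingsBackendService constructs an EmbeddingsBackendService from its
// dependencies.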
func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService {
	return &EmbeddingsBackendService{
		ml:        ml,
		bcl:       bcl,
		appConfig: appConfig,
	}
}
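
// Embeddings resolves the backend configuration for the request, computes an
// embedding for every tokenized and string input, and delivers a single
// OpenAI-style response (or an error) on the returned channel before closing it.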
func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {
	resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
	go func(request *schema.OpenAIRequest) {
		if request.Model == "" {
			request.Model = model.StableDiffusionBackend
		}
		bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig)
		if err != nil {
			resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
			close(resultChannel)
			return
		}
		items := []schema.Item{}
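		// Embed each pre-tokenized input first.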
		for i, s := range bc.InputToken {
			// get the model function to call for the result
			embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig)
			if err != nil {
				resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
				close(resultChannel)
				return
			}
			embeddings, err := embedFn()
			if err != nil {
				resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
				close(resultChannel)
				return
			}
			items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}
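		// Then embed each plain string input.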
		for i, s := range bc.InputStrings {
			// get the model function to call for the result
			embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig)
			if err != nil {
				resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
				close(resultChannel)
				return
			}
			embeddings, err := embedFn()
			if err != nil {
				resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
				close(resultChannel)
				return
			}
			items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}
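		// Assemble the OpenAI-compatible response and hand it to the caller.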
		id := uuid.New().String()
		created := int(time.Now().Unix())
		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   request.Model, // we have to return what the user sent here, due to OpenAI spec.
			Data:    items,
			Object:  "list",
		}
		resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
		close(resultChannel)
	}(request)
	return resultChannel
}
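
// modelEmbedding loads the model described by backendConfig and returns a
// function that computes an embedding for either the string s or the token
// sequence tokens (tokens take precedence when non-empty).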
func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
	modelFile := backendConfig.Model
	grpcOpts := gRPCModelOpts(backendConfig)

	var inferenceModel interface{}
	var err error

	opts := modelOpts(backendConfig, appConfig, []model.Option{
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
		model.WithThreads(uint32(*backendConfig.Threads)),
		model.WithAssetDir(appConfig.AssetsDestination),
		model.WithModel(modelFile),
		model.WithContext(appConfig.Context),
	})
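	// With no backend configured, greedily try the available backends;
	// otherwise load the one named in the config.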
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
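
	// Only gRPC backends expose an Embeddings call; anything else gets a stub
	// that reports embeddings as unsupported.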
	var fn func() ([]float32, error)
	switch model := inferenceModel.(type) {
	case grpc.Backend:
		fn = func() ([]float32, error) {
			predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
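			// Token input is forwarded via EmbeddingTokens.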
			if len(tokens) > 0 {
				embeds := []int32{}
				for _, t := range tokens {
					embeds = append(embeds, int32(t))
				}
				predictOptions.EmbeddingTokens = embeds
				res, err := model.Embeddings(appConfig.Context, predictOptions)
				if err != nil {
					return nil, err
				}
				return res.Embeddings, nil
			}
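			// String input is forwarded via the Embeddings field instead.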
			predictOptions.Embeddings = s
			res, err := model.Embeddings(appConfig.Context, predictOptions)
			if err != nil {
				return nil, err
			}
			return res.Embeddings, nil
		}
	default:
		fn = func() ([]float32, error) {
			return nil, fmt.Errorf("embeddings not supported by the backend")
		}
	}
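
	// Wrap the backend call so trailing zero padding is stripped from the
	// returned vector.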
	return func() ([]float32, error) {
		embeds, err := fn()
		if err != nil {
			return embeds, err
		}
		// Remove trailing 0s
		for i := len(embeds) - 1; i >= 0; i-- {
			if embeds[i] == 0.0 {
				embeds = embeds[:i]
			} else {
				break
			}
		}
		return embeds, nil
	}, nil
}