LocalAI/core/backend/rerank.go
Ettore Di Giacinto 0965c6cd68
feat: track internally started models by ID (#3693)
* chore(refactor): track internally started models by ID

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Just extend options, no need to copy

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve debugging for rerankers failures

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Simplify model loading with rerankers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Be more consistent when generating model options

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Uncommitted code

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make deleteProcess more idiomatic

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt CLI for sound generation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fixup threads definition

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Handle corner case where c.Seed is nil

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Consistently use ModelOptions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt new code to refactoring

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Dave <dave@gray101.com>
2024-10-02 08:55:58 +02:00

28 lines
720 B
Go

package backend
import (
"context"
"fmt"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
rerankModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if rerankModel == nil {
return nil, fmt.Errorf("could not load rerank model")
}
res, err := rerankModel.Rerank(context.Background(), request)
return res, err
}