mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-02 08:43:10 +00:00
fix(model-loading): keep track of open GRPC Clients (#3377)
Due to a previous refactor we moved the client constructor tight to the model address, however that was just a string which we would use to build the client each time. With this change we make the loader to return a *Model which carries a constructor for the client and stores the client on the first connection. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
771a052480
commit
7f06954425
@ -107,7 +107,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
|
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
|
||||||
if modelAddr == "" {
|
if modelAddr == nil {
|
||||||
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,10 +18,10 @@ func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool)
|
|||||||
if bc, ok := embeds[address]; ok {
|
if bc, ok := embeds[address]; ok {
|
||||||
return bc
|
return bc
|
||||||
}
|
}
|
||||||
return NewGrpcClient(address, parallel, wd, enableWatchDog)
|
return buildClient(address, parallel, wd, enableWatchDog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
|
func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
|
||||||
if !enableWatchDog {
|
if !enableWatchDog {
|
||||||
wd = nil
|
wd = nil
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,18 @@ func (c *Client) setBusy(v bool) {
|
|||||||
c.Unlock()
|
c.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) wdMark() {
|
||||||
|
if c.wd != nil {
|
||||||
|
c.wd.Mark(c.address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) wdUnMark() {
|
||||||
|
if c.wd != nil {
|
||||||
|
c.wd.UnMark(c.address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
|
func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
@ -76,10 +88,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -97,10 +107,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -118,10 +126,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -138,10 +144,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -177,10 +181,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -197,10 +199,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -217,10 +217,8 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -237,10 +235,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -277,10 +273,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
if c.wd != nil {
|
c.wdMark()
|
||||||
c.wd.Mark(c.address)
|
defer c.wdUnMark()
|
||||||
defer c.wd.UnMark(c.address)
|
|
||||||
}
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -319,6 +313,8 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -333,6 +329,8 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
|
|||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
defer c.opMutex.Unlock()
|
defer c.opMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
@ -351,6 +349,8 @@ func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ..
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -367,6 +367,8 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -383,6 +385,8 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.
|
|||||||
}
|
}
|
||||||
c.setBusy(true)
|
c.setBusy(true)
|
||||||
defer c.setBusy(false)
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -80,6 +80,9 @@ ENTRY:
|
|||||||
if e.IsDir() {
|
if e.IsDir() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(e.Name(), ".log") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Skip the llama.cpp variants if we are autoDetecting
|
// Skip the llama.cpp variants if we are autoDetecting
|
||||||
// But we always load the fallback variant if it exists
|
// But we always load the fallback variant if it exists
|
||||||
@ -265,12 +268,12 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string {
|
|||||||
|
|
||||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||||
// It also loads the model
|
// It also loads the model
|
||||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
|
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) {
|
||||||
return func(modelName, modelFile string) (ModelAddress, error) {
|
return func(modelName, modelFile string) (*Model, error) {
|
||||||
|
|
||||||
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
|
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
|
||||||
|
|
||||||
var client ModelAddress
|
var client *Model
|
||||||
|
|
||||||
getFreeAddress := func() (string, error) {
|
getFreeAddress := func() (string, error) {
|
||||||
port, err := freeport.GetFreePort()
|
port, err := freeport.GetFreePort()
|
||||||
@ -298,26 +301,26 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
log.Debug().Msgf("external backend is file: %+v", fi)
|
log.Debug().Msgf("external backend is file: %+v", fi)
|
||||||
serverAddress, err := getFreeAddress()
|
serverAddress, err := getFreeAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||||
}
|
}
|
||||||
// Make sure the process is executable
|
// Make sure the process is executable
|
||||||
if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
|
if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
|
||||||
log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
|
log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("GRPC Service Started")
|
log.Debug().Msgf("GRPC Service Started")
|
||||||
|
|
||||||
client = ModelAddress(serverAddress)
|
client = NewModel(serverAddress)
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Msg("external backend is uri")
|
log.Debug().Msg("external backend is uri")
|
||||||
// address
|
// address
|
||||||
client = ModelAddress(uri)
|
client = NewModel(uri)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
grpcProcess := backendPath(o.assetDir, backend)
|
grpcProcess := backendPath(o.assetDir, backend)
|
||||||
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
|
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
|
||||||
return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
|
return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if autoDetect {
|
if autoDetect {
|
||||||
@ -329,12 +332,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
|
|
||||||
// Check if the file exists
|
// Check if the file exists
|
||||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||||
return "", fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverAddress, err := getFreeAddress()
|
serverAddress, err := getFreeAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{}
|
args := []string{}
|
||||||
@ -344,12 +347,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
|
|
||||||
// Make sure the process is executable in any circumstance
|
// Make sure the process is executable in any circumstance
|
||||||
if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil {
|
if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("GRPC Service Started")
|
log.Debug().Msgf("GRPC Service Started")
|
||||||
|
|
||||||
client = ModelAddress(serverAddress)
|
client = NewModel(serverAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for the service to start up
|
// Wait for the service to start up
|
||||||
@ -369,7 +372,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
|
|
||||||
if !ready {
|
if !ready {
|
||||||
log.Debug().Msgf("GRPC Service NOT ready")
|
log.Debug().Msgf("GRPC Service NOT ready")
|
||||||
return "", fmt.Errorf("grpc service not ready")
|
return nil, fmt.Errorf("grpc service not ready")
|
||||||
}
|
}
|
||||||
|
|
||||||
options := *o.gRPCOptions
|
options := *o.gRPCOptions
|
||||||
@ -380,27 +383,16 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
|
|
||||||
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
|
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not load model: %w", err)
|
return nil, fmt.Errorf("could not load model: %w", err)
|
||||||
}
|
}
|
||||||
if !res.Success {
|
if !res.Success {
|
||||||
return "", fmt.Errorf("could not load model (no success): %s", res.Message)
|
return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
|
|
||||||
if parallel {
|
|
||||||
return addr.GRPC(parallel, ml.wd), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := ml.grpcClients[string(addr)]; !ok {
|
|
||||||
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
|
|
||||||
}
|
|
||||||
return ml.grpcClients[string(addr)], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
|
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
|
||||||
o := NewOptions(opts...)
|
o := NewOptions(opts...)
|
||||||
|
|
||||||
@ -425,7 +417,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
|
|||||||
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
|
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var backendToConsume string
|
var backendToConsume string
|
||||||
@ -438,26 +429,28 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
|
|||||||
backendToConsume = backend
|
backendToConsume = backend
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
|
model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return ml.resolveAddress(addr, o.parallelRequests)
|
return model.GRPC(o.parallelRequests, ml.wd), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
|
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
|
||||||
o := NewOptions(opts...)
|
o := NewOptions(opts...)
|
||||||
|
|
||||||
ml.mu.Lock()
|
ml.mu.Lock()
|
||||||
|
|
||||||
// Return earlier if we have a model already loaded
|
// Return earlier if we have a model already loaded
|
||||||
// (avoid looping through all the backends)
|
// (avoid looping through all the backends)
|
||||||
if m := ml.CheckIsLoaded(o.model); m != "" {
|
if m := ml.CheckIsLoaded(o.model); m != nil {
|
||||||
log.Debug().Msgf("Model '%s' already loaded", o.model)
|
log.Debug().Msgf("Model '%s' already loaded", o.model)
|
||||||
ml.mu.Unlock()
|
ml.mu.Unlock()
|
||||||
|
|
||||||
return ml.resolveAddress(m, o.parallelRequests)
|
return m.GRPC(o.parallelRequests, ml.wd), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we can have only one backend active, kill all the others (except external backends)
|
// If we can have only one backend active, kill all the others (except external backends)
|
||||||
if o.singleActiveBackend {
|
if o.singleActiveBackend {
|
||||||
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
|
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
|
||||||
|
@ -10,67 +10,28 @@ import (
|
|||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/templates"
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc"
|
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
process "github.com/mudler/go-processmanager"
|
process "github.com/mudler/go-processmanager"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rather than pass an interface{} to the prompt template:
|
|
||||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
|
||||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
|
||||||
type PromptTemplateData struct {
|
|
||||||
SystemPrompt string
|
|
||||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
|
||||||
Input string
|
|
||||||
Instruction string
|
|
||||||
Functions []functions.Function
|
|
||||||
MessageIndex int
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatMessageTemplateData struct {
|
|
||||||
SystemPrompt string
|
|
||||||
Role string
|
|
||||||
RoleName string
|
|
||||||
FunctionName string
|
|
||||||
Content string
|
|
||||||
MessageIndex int
|
|
||||||
Function bool
|
|
||||||
FunctionCall interface{}
|
|
||||||
LastMessage bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// new idea: what if we declare a struct of these here, and use a loop to check?
|
// new idea: what if we declare a struct of these here, and use a loop to check?
|
||||||
|
|
||||||
// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
|
// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
|
||||||
type ModelLoader struct {
|
type ModelLoader struct {
|
||||||
ModelPath string
|
ModelPath string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// TODO: this needs generics
|
models map[string]*Model
|
||||||
grpcClients map[string]grpc.Backend
|
|
||||||
models map[string]ModelAddress
|
|
||||||
grpcProcesses map[string]*process.Process
|
grpcProcesses map[string]*process.Process
|
||||||
templates *templates.TemplateCache
|
templates *templates.TemplateCache
|
||||||
wd *WatchDog
|
wd *WatchDog
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelAddress string
|
|
||||||
|
|
||||||
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
|
|
||||||
enableWD := false
|
|
||||||
if wd != nil {
|
|
||||||
enableWD = true
|
|
||||||
}
|
|
||||||
return grpc.NewClient(string(m), parallel, wd, enableWD)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewModelLoader(modelPath string) *ModelLoader {
|
func NewModelLoader(modelPath string) *ModelLoader {
|
||||||
nml := &ModelLoader{
|
nml := &ModelLoader{
|
||||||
ModelPath: modelPath,
|
ModelPath: modelPath,
|
||||||
grpcClients: make(map[string]grpc.Backend),
|
models: make(map[string]*Model),
|
||||||
models: make(map[string]ModelAddress),
|
|
||||||
templates: templates.NewTemplateCache(modelPath),
|
templates: templates.NewTemplateCache(modelPath),
|
||||||
grpcProcesses: make(map[string]*process.Process),
|
grpcProcesses: make(map[string]*process.Process),
|
||||||
}
|
}
|
||||||
@ -141,12 +102,12 @@ FILE:
|
|||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (ModelAddress, error)) (ModelAddress, error) {
|
func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
|
||||||
ml.mu.Lock()
|
ml.mu.Lock()
|
||||||
defer ml.mu.Unlock()
|
defer ml.mu.Unlock()
|
||||||
|
|
||||||
// Check if we already have a loaded model
|
// Check if we already have a loaded model
|
||||||
if model := ml.CheckIsLoaded(modelName); model != "" {
|
if model := ml.CheckIsLoaded(modelName); model != nil {
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,17 +117,9 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
|
|||||||
|
|
||||||
model, err := loader(modelName, modelFile)
|
model, err := loader(modelName, modelFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add a helper method to iterate all prompt templates associated with a config if and only if it's YAML?
|
|
||||||
// Minor perf loss here until this is fixed, but we initialize on first request
|
|
||||||
|
|
||||||
// // If there is a prompt template, load it
|
|
||||||
// if err := ml.loadTemplateIfExists(modelName); err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
ml.models[modelName] = model
|
ml.models[modelName] = model
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
@ -184,55 +137,29 @@ func (ml *ModelLoader) stopModel(modelName string) error {
|
|||||||
return fmt.Errorf("model %s not found", modelName)
|
return fmt.Errorf("model %s not found", modelName)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
//return ml.deleteProcess(modelName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
|
func (ml *ModelLoader) CheckIsLoaded(s string) *Model {
|
||||||
var client grpc.Backend
|
m, ok := ml.models[s]
|
||||||
if m, ok := ml.models[s]; ok {
|
if !ok {
|
||||||
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
return nil
|
||||||
if c, ok := ml.grpcClients[s]; ok {
|
}
|
||||||
client = c
|
|
||||||
} else {
|
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
||||||
client = m.GRPC(false, ml.wd)
|
alive, err := m.GRPC(false, ml.wd).HealthCheck(context.Background())
|
||||||
}
|
if !alive {
|
||||||
alive, err := client.HealthCheck(context.Background())
|
log.Warn().Msgf("GRPC Model not responding: %s", err.Error())
|
||||||
if !alive {
|
log.Warn().Msgf("Deleting the process in order to recreate it")
|
||||||
log.Warn().Msgf("GRPC Model not responding: %s", err.Error())
|
if !ml.grpcProcesses[s].IsAlive() {
|
||||||
log.Warn().Msgf("Deleting the process in order to recreate it")
|
log.Debug().Msgf("GRPC Process is not responding: %s", s)
|
||||||
if !ml.grpcProcesses[s].IsAlive() {
|
// stop and delete the process, this forces to re-load the model and re-create again the service
|
||||||
log.Debug().Msgf("GRPC Process is not responding: %s", s)
|
err := ml.deleteProcess(s)
|
||||||
// stop and delete the process, this forces to re-load the model and re-create again the service
|
if err != nil {
|
||||||
err := ml.deleteProcess(s)
|
log.Error().Err(err).Str("process", s).Msg("error stopping process")
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Str("process", s).Msg("error stopping process")
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return m
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ChatPromptTemplate templates.TemplateType = iota
|
|
||||||
ChatMessageTemplate
|
|
||||||
CompletionPromptTemplate
|
|
||||||
EditPromptTemplate
|
|
||||||
FunctionsPromptTemplate
|
|
||||||
)
|
|
||||||
|
|
||||||
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
|
||||||
// TODO: should this check be improved?
|
|
||||||
if templateType == ChatMessageTemplate {
|
|
||||||
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
|
||||||
}
|
|
||||||
return ml.templates.EvaluateTemplate(templateType, templateName, in)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
|
||||||
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
|
||||||
}
|
}
|
||||||
|
29
pkg/model/model.go
Normal file
29
pkg/model/model.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
address string
|
||||||
|
client grpc.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewModel(address string) *Model {
|
||||||
|
return &Model{
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
|
||||||
|
if m.client != nil {
|
||||||
|
return m.client
|
||||||
|
}
|
||||||
|
|
||||||
|
enableWD := false
|
||||||
|
if wd != nil {
|
||||||
|
enableWD = true
|
||||||
|
}
|
||||||
|
|
||||||
|
client := grpc.NewClient(m.address, parallel, wd, enableWD)
|
||||||
|
m.client = client
|
||||||
|
return client
|
||||||
|
}
|
@ -33,7 +33,7 @@ func (ml *ModelLoader) StopAllExcept(s string) error {
|
|||||||
func (ml *ModelLoader) deleteProcess(s string) error {
|
func (ml *ModelLoader) deleteProcess(s string) error {
|
||||||
if _, exists := ml.grpcProcesses[s]; exists {
|
if _, exists := ml.grpcProcesses[s]; exists {
|
||||||
if err := ml.grpcProcesses[s].Stop(); err != nil {
|
if err := ml.grpcProcesses[s].Stop(); err != nil {
|
||||||
return err
|
log.Error().Err(err).Msgf("(deleteProcess) error while deleting grpc process %s", s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(ml.grpcProcesses, s)
|
delete(ml.grpcProcesses, s)
|
||||||
|
52
pkg/model/template.go
Normal file
52
pkg/model/template.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rather than pass an interface{} to the prompt template:
|
||||||
|
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||||
|
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||||
|
type PromptTemplateData struct {
|
||||||
|
SystemPrompt string
|
||||||
|
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||||
|
Input string
|
||||||
|
Instruction string
|
||||||
|
Functions []functions.Function
|
||||||
|
MessageIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessageTemplateData struct {
|
||||||
|
SystemPrompt string
|
||||||
|
Role string
|
||||||
|
RoleName string
|
||||||
|
FunctionName string
|
||||||
|
Content string
|
||||||
|
MessageIndex int
|
||||||
|
Function bool
|
||||||
|
FunctionCall interface{}
|
||||||
|
LastMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatPromptTemplate templates.TemplateType = iota
|
||||||
|
ChatMessageTemplate
|
||||||
|
CompletionPromptTemplate
|
||||||
|
EditPromptTemplate
|
||||||
|
FunctionsPromptTemplate
|
||||||
|
)
|
||||||
|
|
||||||
|
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||||
|
// TODO: should this check be improved?
|
||||||
|
if templateType == ChatMessageTemplate {
|
||||||
|
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
||||||
|
}
|
||||||
|
return ml.templates.EvaluateTemplate(templateType, templateName, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||||
|
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||||
|
}
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WatchDog tracks all the requests from GRPC clients.
|
||||||
// All GRPC Clients created by ModelLoader should have an associated injected
|
// All GRPC Clients created by ModelLoader should have an associated injected
|
||||||
// watchdog that will keep track of the state of each backend (busy or not)
|
// watchdog that will keep track of the state of each backend (busy or not)
|
||||||
// and for how much time it has been busy.
|
// and for how much time it has been busy.
|
||||||
@ -15,7 +16,6 @@ import (
|
|||||||
// force a reload of the model
|
// force a reload of the model
|
||||||
// The watchdog runs as a separate go routine,
|
// The watchdog runs as a separate go routine,
|
||||||
// and the GRPC client talks to it via a channel to send status updates
|
// and the GRPC client talks to it via a channel to send status updates
|
||||||
|
|
||||||
type WatchDog struct {
|
type WatchDog struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
timetable map[string]time.Time
|
timetable map[string]time.Time
|
||||||
|
Loading…
x
Reference in New Issue
Block a user