mirror of
https://github.com/mudler/LocalAI.git
synced 2025-03-15 16:45:31 +00:00
Because the channel was never closed, a client calling a server that does not implement PredictStream would hang indefinitely, since the call would wait for resultChan to be consumed. Now, once the prediction stream returns, we close the channel and wait for the goroutine to finish. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
266 lines
6.7 KiB
Go
266 lines
6.7 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
// A GRPC Server that allows to run LLM inference.
|
|
// It is used by the LLMServices to expose the LLM functionalities that are called by the client.
|
|
// The GRPC Service is general, trying to encompass all the possible LLM options models.
|
|
// It depends on the real implementer then what can be done or not.
|
|
//
|
|
// The server is implemented as a GRPC service, with the following methods:
|
|
// - Predict: to run the inference with options
|
|
// - PredictStream: to run the inference with options and stream the results
|
|
|
|
// server is used to implement helloworld.GreeterServer.
|
|
type server struct {
|
|
pb.UnimplementedBackendServer
|
|
llm LLM
|
|
}
|
|
|
|
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
|
|
return newReply("OK"), nil
|
|
}
|
|
|
|
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
embeds, err := s.llm.Embeddings(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &pb.EmbeddingResult{Embeddings: embeds}, nil
|
|
}
|
|
|
|
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.Load(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "Loading succeeded", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
result, err := s.llm.Predict(in)
|
|
return newReply(result), err
|
|
}
|
|
|
|
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.GenerateImage(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "Image generated", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.TTS(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.SoundGeneration(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
result, err := s.llm.AudioTranscription(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tresult := &pb.TranscriptResult{}
|
|
for _, s := range result.Segments {
|
|
tks := []int32{}
|
|
for _, t := range s.Tokens {
|
|
tks = append(tks, int32(t))
|
|
}
|
|
tresult.Segments = append(tresult.Segments,
|
|
&pb.TranscriptSegment{
|
|
Text: s.Text,
|
|
Id: int32(s.Id),
|
|
Start: int64(s.Start),
|
|
End: int64(s.End),
|
|
Tokens: tks,
|
|
})
|
|
}
|
|
|
|
tresult.Text = result.Text
|
|
return tresult, nil
|
|
}
|
|
|
|
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
resultChan := make(chan string)
|
|
|
|
done := make(chan bool)
|
|
go func() {
|
|
for result := range resultChan {
|
|
stream.Send(newReply(result))
|
|
}
|
|
done <- true
|
|
}()
|
|
|
|
err := s.llm.PredictStream(in, resultChan)
|
|
// close the channel, so if resultChan is not closed by the LLM (maybe because does not implement PredictStream), the client will not hang
|
|
close(resultChan)
|
|
<-done
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
res, err := s.llm.TokenizeString(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
castTokens := make([]int32, len(res.Tokens))
|
|
for i, v := range res.Tokens {
|
|
castTokens[i] = int32(v)
|
|
}
|
|
|
|
return &pb.TokenizationResponse{
|
|
Length: int32(res.Length),
|
|
Tokens: castTokens,
|
|
}, err
|
|
}
|
|
|
|
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) {
|
|
res, err := s.llm.Status()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &res, nil
|
|
}
|
|
|
|
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.StoresSet(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "Set key", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
err := s.llm.StoresDelete(in)
|
|
if err != nil {
|
|
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
|
|
}
|
|
return &pb.Result{Message: "Deleted key", Success: true}, nil
|
|
}
|
|
|
|
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
res, err := s.llm.StoresGet(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &res, nil
|
|
}
|
|
|
|
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
|
|
if s.llm.Locking() {
|
|
s.llm.Lock()
|
|
defer s.llm.Unlock()
|
|
}
|
|
res, err := s.llm.StoresFind(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &res, nil
|
|
}
|
|
|
|
func StartServer(address string, model LLM) error {
|
|
lis, err := net.Listen("tcp", address)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s := grpc.NewServer()
|
|
pb.RegisterBackendServer(s, &server{llm: model})
|
|
log.Printf("gRPC Server listening at %v", lis.Addr())
|
|
if err := s.Serve(lis); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func RunServer(address string, model LLM) (func() error, error) {
|
|
lis, err := net.Listen("tcp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s := grpc.NewServer()
|
|
pb.RegisterBackendServer(s, &server{llm: model})
|
|
log.Printf("gRPC Server listening at %v", lis.Addr())
|
|
if err = s.Serve(lis); err != nil {
|
|
return func() error {
|
|
return lis.Close()
|
|
}, err
|
|
}
|
|
|
|
return func() error {
|
|
s.GracefulStop()
|
|
return nil
|
|
}, nil
|
|
}
|