mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-21 20:08:51 +00:00
feat: stream tokens usage (#4415)
* Use pb.Reply instead of []byte with Reply.GetMessage() in llama grpc to get the proper usage data in reply streaming mode at the last [DONE] frame * Fix 'hang' on empty message from the start Seems like that empty message marker trick was unnecessary --------- Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
fc920cc58a
commit
2bc4b56a79
@ -117,8 +117,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||||||
ss := ""
|
ss := ""
|
||||||
|
|
||||||
var partialRune []byte
|
var partialRune []byte
|
||||||
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
|
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
|
||||||
partialRune = append(partialRune, chars...)
|
msg := reply.Message
|
||||||
|
partialRune = append(partialRune, msg...)
|
||||||
|
|
||||||
|
tokenUsage.Prompt = int(reply.PromptTokens)
|
||||||
|
tokenUsage.Completion = int(reply.Tokens)
|
||||||
|
|
||||||
for len(partialRune) > 0 {
|
for len(partialRune) > 0 {
|
||||||
r, size := utf8.DecodeRune(partialRune)
|
r, size := utf8.DecodeRune(partialRune)
|
||||||
@ -132,6 +136,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||||||
|
|
||||||
partialRune = partialRune[size:]
|
partialRune = partialRune[size:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(msg) == 0 {
|
||||||
|
tokenCallback("", tokenUsage)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return LLMResponse{
|
return LLMResponse{
|
||||||
Response: ss,
|
Response: ss,
|
||||||
|
@ -37,7 +37,7 @@ type Backend interface {
|
|||||||
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
|
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
|
||||||
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
||||||
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
|
||||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
|
@ -136,7 +136,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
|
|||||||
return client.LoadModel(ctx, in, opts...)
|
return client.LoadModel(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
defer c.opMutex.Unlock()
|
defer c.opMutex.Unlock()
|
||||||
@ -158,7 +158,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
feature, err := stream.Recv()
|
reply, err := stream.Recv()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -167,7 +167,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f(feature.GetMessage())
|
f(reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -35,7 +35,7 @@ func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts
|
|||||||
return e.s.LoadModel(ctx, in)
|
return e.s.LoadModel(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||||
bs := &embedBackendServerStream{
|
bs := &embedBackendServerStream{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
fn: f,
|
fn: f,
|
||||||
@ -97,11 +97,11 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
|
|||||||
|
|
||||||
type embedBackendServerStream struct {
|
type embedBackendServerStream struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
fn func(s []byte)
|
fn func(reply *pb.Reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
|
func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
|
||||||
e.fn(reply.GetMessage())
|
e.fn(reply)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user