feat: stream tokens usage (#4415)

* Use pb.Reply instead of []byte with Reply.GetMessage() in llama grpc to get the proper usage data in reply streaming mode at the last [DONE] frame

* Fix 'hang' on empty message from the start

Seems like that empty message marker trick was unnecessary

---------

Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
mintyleaf 2024-12-18 12:48:50 +04:00 committed by GitHub
parent fc920cc58a
commit 2bc4b56a79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 9 deletions

View File

@ -117,8 +117,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
ss := "" ss := ""
var partialRune []byte var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
partialRune = append(partialRune, chars...) msg := reply.Message
partialRune = append(partialRune, msg...)
tokenUsage.Prompt = int(reply.PromptTokens)
tokenUsage.Completion = int(reply.Tokens)
for len(partialRune) > 0 { for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune) r, size := utf8.DecodeRune(partialRune)
@ -132,6 +136,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
partialRune = partialRune[size:] partialRune = partialRune[size:]
} }
if len(msg) == 0 {
tokenCallback("", tokenUsage)
}
}) })
return LLMResponse{ return LLMResponse{
Response: ss, Response: ss,

View File

@ -37,7 +37,7 @@ type Backend interface {
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)

View File

@ -136,7 +136,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
return client.LoadModel(ctx, in, opts...) return client.LoadModel(ctx, in, opts...)
} }
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
if !c.parallel { if !c.parallel {
c.opMutex.Lock() c.opMutex.Lock()
defer c.opMutex.Unlock() defer c.opMutex.Unlock()
@ -158,7 +158,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
} }
for { for {
feature, err := stream.Recv() reply, err := stream.Recv()
if err == io.EOF { if err == io.EOF {
break break
} }
@ -167,7 +167,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
return err return err
} }
f(feature.GetMessage()) f(reply)
} }
return nil return nil

View File

@ -35,7 +35,7 @@ func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts
return e.s.LoadModel(ctx, in) return e.s.LoadModel(ctx, in)
} }
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
bs := &embedBackendServerStream{ bs := &embedBackendServerStream{
ctx: ctx, ctx: ctx,
fn: f, fn: f,
@ -97,11 +97,11 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
type embedBackendServerStream struct { type embedBackendServerStream struct {
ctx context.Context ctx context.Context
fn func(s []byte) fn func(reply *pb.Reply)
} }
func (e *embedBackendServerStream) Send(reply *pb.Reply) error { func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
e.fn(reply.GetMessage()) e.fn(reply)
return nil return nil
} }