mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-21 09:41:45 +00:00
feat(grpc): return consumed token count and update response accordingly (#2035)
Fixes: #1920
This commit is contained in:
parent
de3a1a0a8e
commit
e843d7df0e
@ -114,6 +114,8 @@ message PredictOptions {
|
||||
// The response message containing the result
|
||||
message Reply {
|
||||
bytes message = 1;
|
||||
int32 tokens = 2;
|
||||
int32 prompt_tokens = 3;
|
||||
}
|
||||
|
||||
message ModelOptions {
|
||||
|
@ -2332,6 +2332,10 @@ public:
|
||||
std::string completion_text = result.result_json.value("content", "");
|
||||
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Send the reply
|
||||
writer->Write(reply);
|
||||
@ -2357,6 +2361,10 @@ public:
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
if (!result.error && result.stop) {
|
||||
completion_text = result.result_json.value("content", "");
|
||||
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
|
||||
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
reply->set_message(completion_text);
|
||||
}
|
||||
else
|
||||
|
@ -189,6 +189,12 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest,
|
||||
} else {
|
||||
go func() {
|
||||
reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
|
||||
if tokenUsage.Prompt == 0 {
|
||||
tokenUsage.Prompt = int(reply.PromptTokens)
|
||||
}
|
||||
if tokenUsage.Completion == 0 {
|
||||
tokenUsage.Completion = int(reply.Tokens)
|
||||
}
|
||||
if err != nil {
|
||||
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
|
||||
close(rawResultChannel)
|
||||
|
@ -160,7 +160,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
|
||||
|
||||
bc, request, err := oais.getConfig(request)
|
||||
if err != nil {
|
||||
log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
|
||||
log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration")
|
||||
return
|
||||
}
|
||||
|
||||
@ -259,7 +259,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
|
||||
// If any of the setup goroutines experienced an error, quit early here.
|
||||
if setupError != nil {
|
||||
go func() {
|
||||
log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
|
||||
log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup")
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
|
||||
close(rawFinalResultChannel)
|
||||
}()
|
||||
@ -603,7 +603,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: rawResult.Value.Usage.Prompt,
|
||||
CompletionTokens: rawResult.Value.Usage.Completion,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
@ -644,7 +644,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: rawResult.Value.Usage.Prompt,
|
||||
CompletionTokens: rawResult.Value.Usage.Completion,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user