diff --git a/backend/backend.proto b/backend/backend.proto
index b2d4518e..568655b6 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -26,6 +26,19 @@ service Backend {
   rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
 
   rpc Rerank(RerankRequest) returns (RerankResult) {}
+
+  rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
+}
+
+// Define the empty request
+message MetricsRequest {}
+
+message MetricsResponse {
+  int32 slot_id = 1;
+  string prompt_json_for_slot = 2; // Stores the prompt as a JSON string.
+  float tokens_per_second = 3;
+  int32 tokens_generated = 4;
+  int32 prompt_tokens_processed = 5;
 }
 
 message RerankRequest {
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 791612db..be99bf76 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -495,6 +495,16 @@ struct llama_server_context
         }
     }
 
+    llama_client_slot* get_active_slot() {
+        for (llama_client_slot& slot : slots) {
+            // Check if the slot is currently processing
+            if (slot.is_processing()) {
+                return &slot; // Return the active slot
+            }
+        }
+        return nullptr; // No active slot found
+    }
+
     void initialize() {
         // create slots
         all_slots_are_idle = true;
@@ -2420,6 +2430,31 @@ public:
 
         return grpc::Status::OK;
     }
+
+    grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
+        llama_client_slot* active_slot = llama.get_active_slot();
+
+        if (active_slot != nullptr) {
+            // Calculate the tokens per second using existing logic
+            double tokens_per_second = 1e3 / active_slot->t_token_generation * active_slot->n_decoded;
+
+            // Populate the response with metrics
+            response->set_slot_id(active_slot->id);
+            response->set_prompt_json_for_slot(active_slot->prompt.dump());
+            response->set_tokens_per_second(tokens_per_second);
+            response->set_tokens_generated(active_slot->n_decoded);
+            response->set_prompt_tokens_processed(active_slot->num_prompt_tokens_processed);
+        } else {
+            // Handle case when no active slot exists
+            response->set_slot_id(0);
+            response->set_prompt_json_for_slot("");
+            response->set_tokens_per_second(0);
+            response->set_tokens_generated(0);
+            response->set_prompt_tokens_processed(0);
+        }
+
+        return grpc::Status::OK;
+    }
 };
 
 void RunServer(const std::string& server_address) {
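A note on the tokens-per-second expression above: GetMetrics reports 1e3 / t_token_generation * n_decoded, which only yields tokens per second if t_token_generation is the slot's accumulated generation time in milliseconds and n_decoded is the number of tokens generated so far; both are assumptions about the existing llama.cpp slot fields, not something this patch changes. A minimal Go sketch of the same arithmetic, with hypothetical names:

    // tokensPerSecond mirrors the computation in GetMetrics: convert the
    // accumulated generation time from milliseconds to seconds and divide
    // the generated-token count by it.
    func tokensPerSecond(tTokenGenerationMs float64, nDecoded int) float64 {
        if tTokenGenerationMs <= 0 {
            return 0 // avoid dividing by zero before any token has been generated
        }
        return float64(nDecoded) / (tTokenGenerationMs / 1e3)
    }

    // Example: 512 tokens generated over 8000 ms -> 512 / 8 = 64 tokens/s.

The C++ code divides first (1e3 / t_token_generation) and does not guard against a zero generation time; the sketch adds that guard only for illustration.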
diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go
new file mode 100644
index 00000000..cd715108
--- /dev/null
+++ b/core/backend/token_metrics.go
@@ -0,0 +1,44 @@
+package backend
+
+import (
+    "context"
+    "fmt"
+
+    "github.com/mudler/LocalAI/core/config"
+    "github.com/mudler/LocalAI/pkg/grpc/proto"
+    model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func TokenMetrics(
+    backend,
+    modelFile string,
+    loader *model.ModelLoader,
+    appConfig *config.ApplicationConfig,
+    backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
+    bb := backend
+    if bb == "" {
+        return nil, fmt.Errorf("backend is required")
+    }
+
+    grpcOpts := GRPCModelOpts(backendConfig)
+
+    opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
+        model.WithBackendString(bb),
+        model.WithModel(modelFile),
+        model.WithContext(appConfig.Context),
+        model.WithAssetDir(appConfig.AssetsDestination),
+        model.WithLoadGRPCLoadModelOpts(grpcOpts),
+    })
+    model, err := loader.BackendLoader(opts...)
+    if err != nil {
+        return nil, err
+    }
+
+    if model == nil {
+        return nil, fmt.Errorf("could not load model")
+    }
+
+    res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
+
+    return res, err
+}
diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go
new file mode 100644
index 00000000..95e79bac
--- /dev/null
+++ b/core/http/endpoints/localai/get_token_metrics.go
@@ -0,0 +1,60 @@
+package localai
+
+import (
+    "github.com/gofiber/fiber/v2"
+    "github.com/mudler/LocalAI/core/backend"
+    "github.com/mudler/LocalAI/core/config"
+    fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+    "github.com/mudler/LocalAI/core/schema"
+    "github.com/rs/zerolog/log"
+
+    "github.com/mudler/LocalAI/pkg/model"
+)
+
+// TokenMetricsEndpoint returns tokens-per-second and token counts for the currently active slot
+//
+// @Summary Get token metrics for the active slot.
+// @Accept json
+// @Produce json
+// @Success 200 {object} proto.MetricsResponse "Token metrics for the active slot"
+// @Router /v1/tokenMetrics [get]
+// @Router /tokenMetrics [get]
+func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+    return func(c *fiber.Ctx) error {
+
+        input := new(schema.TokenMetricsRequest)
+
+        // Get input data from the request body
+        if err := c.BodyParser(input); err != nil {
+            return err
+        }
+
+        modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
+        if err != nil {
+            modelFile = input.Model
+            log.Warn().Msgf("Model not found in context: %s", input.Model)
+        }
+
+        cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
+            config.LoadOptionDebug(appConfig.Debug),
+            config.LoadOptionThreads(appConfig.Threads),
+            config.LoadOptionContextSize(appConfig.ContextSize),
+            config.LoadOptionF16(appConfig.F16),
+        )
+
+        if err != nil {
+            log.Err(err).Msgf("could not load backend config for model: %s", modelFile)
+            modelFile = input.Model
+            log.Warn().Msgf("falling back to the model name from the request: %s", input.Model)
+        } else {
+            modelFile = cfg.Model
+        }
+        log.Debug().Msgf("Token Metrics for model: %s", modelFile)
+
+        response, err := backend.TokenMetrics(cfg.Backend, modelFile, ml, appConfig, *cfg)
+        if err != nil {
+            return err
+        }
+        return c.JSON(response)
+    }
+}
diff --git a/core/schema/localai.go b/core/schema/localai.go
index 75fa40c7..cdc3e5b0 100644
--- a/core/schema/localai.go
+++ b/core/schema/localai.go
@@ -10,6 +10,10 @@ type BackendMonitorRequest struct {
     Model string `json:"model" yaml:"model"`
 }
 
+type TokenMetricsRequest struct {
+    Model string `json:"model" yaml:"model"`
+}
+
 type BackendMonitorResponse struct {
     MemoryInfo    *gopsutil.MemoryInfoStat
     MemoryPercent float32
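For reference, a sketch of how the new HTTP endpoint might be exercised once a route for the @Router paths above is registered (the route registration itself is not part of this diff). The base URL with the default port 8080, and the GET-with-JSON-body style implied by the annotations and by c.BodyParser, are assumptions:

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        // TokenMetricsRequest only carries the model name.
        body := bytes.NewBufferString(`{"model": "my-model.gguf"}`) // hypothetical model name

        req, err := http.NewRequest(http.MethodGet, "http://localhost:8080/tokenMetrics", body)
        if err != nil {
            panic(err)
        }
        req.Header.Set("Content-Type", "application/json")

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        // The handler responds with the JSON-serialized MetricsResponse
        // (slot id, prompt, tokens per second, token counts).
        out, _ := io.ReadAll(resp.Body)
        fmt.Println(string(out))
    }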
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index 85c9e5bc..637a6db1 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -51,4 +51,6 @@ type Backend interface {
     StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
 
     Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
+
+    GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
 }
diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go
index 032c9c00..14481620 100644
--- a/pkg/grpc/client.go
+++ b/pkg/grpc/client.go
@@ -374,3 +374,21 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.
     client := pb.NewBackendClient(conn)
     return client.Rerank(ctx, in, opts...)
 }
+
+func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
+    if !c.parallel {
+        c.opMutex.Lock()
+        defer c.opMutex.Unlock()
+    }
+    c.setBusy(true)
+    defer c.setBusy(false)
+    c.wdMark()
+    defer c.wdUnMark()
+    conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
+    if err != nil {
+        return nil, err
+    }
+    defer conn.Close()
+    client := pb.NewBackendClient(conn)
+    return client.GetMetrics(ctx, in, opts...)
+}
diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go
index 3155ff59..cf624344 100644
--- a/pkg/grpc/embed.go
+++ b/pkg/grpc/embed.go
@@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ..
     return e.s.Rerank(ctx, in)
 }
 
+func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
+    return e.s.GetMetrics(ctx, in)
+}
+
 type embedBackendServerStream struct {
     ctx context.Context
     fn  func(s []byte)
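A last sketch, mirroring what the new Client.GetTokenMetrics wrapper does: querying a running backend process directly over gRPC. The backend address is an assumption; in normal operation the address is assigned and dialed by the model loader rather than by hand.

    package main

    import (
        "context"
        "fmt"
        "log"

        pb "github.com/mudler/LocalAI/pkg/grpc/proto"
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
    )

    func main() {
        // Dial the backend's gRPC address (assumed here; backends normally get
        // a dynamically assigned address from the model loader).
        conn, err := grpc.Dial("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
        if err != nil {
            log.Fatal(err)
        }
        defer conn.Close()

        client := pb.NewBackendClient(conn)
        res, err := client.GetMetrics(context.Background(), &pb.MetricsRequest{})
        if err != nil {
            log.Fatal(err)
        }
        // Prints slot id, prompt JSON, tokens/s and token counts for the active slot.
        fmt.Printf("%+v\n", res)
    }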