Usage Features (#863)

Dave 2023-08-18 15:23:14 -04:00 committed by GitHub
parent 2bacd0180d
commit 8cb1061c11
40 changed files with 1222 additions and 317 deletions

.gitignore

@ -22,6 +22,8 @@ LocalAI
local-ai
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder
!api/localai
# Ignore models
models/*


@ -4,7 +4,7 @@ GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
# llama.cpp versions
GOLLAMA_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
GOLLAMA_VERSION?=f03869d188b72c8a617bea3a36cf8eb43f73445c
# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all


@ -2,6 +2,7 @@ package api
import (
"errors"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
@ -19,7 +20,7 @@ import (
"github.com/rs/zerolog/log"
)
func App(opts ...options.AppOption) (*fiber.App, error) {
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
@ -27,6 +28,65 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
@ -57,36 +117,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}))
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cm := config.NewConfigLoader()
if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cm.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if options.Debug {
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// Default middleware config
app.Use(recover.New())
@ -116,18 +146,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
return c.Next()
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil {
return nil, err
}
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
@ -141,7 +159,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cm)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
@ -149,36 +167,36 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}{Version: internal.PrintableVersion()})
})
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cl, galleryService.C, options.Galleries))
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options))
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options))
app.Post("/edits", auth, openai.EditEndpoint(cm, options))
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options))
app.Post("/tts", auth, localai.TTSEndpoint(cm, options))
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options))
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
@ -196,16 +214,13 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopGRPC()
}()
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
return app, nil
}
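The net effect of this refactor: everything through model preloading and the gRPC shutdown watcher now lives in Startup, and App is a thin wrapper that adds the HTTP layer on top. A minimal sketch of how a caller chooses between the two (only Startup and App are from this commit; the rest is illustrative):

func run(preloadOnly bool, opts ...options.AppOption) error {
	if preloadOnly {
		// Startup alone: configs, backend assets, model preload, no HTTP server.
		_, _, err := api.Startup(opts...)
		return err
	}
	app, err := api.App(opts...) // App begins by calling Startup itself
	if err != nil {
		return err
	}
	return app.Listen(":8080") // listen address is an assumption, not from this diff
}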


@ -2,7 +2,6 @@ package backend
import (
"fmt"
"sync"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
@ -88,18 +87,6 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
return func() ([]float32, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
embeds, err := fn()
if err != nil {
return embeds, err


@ -1,8 +1,6 @@
package backend
import (
"sync"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -67,19 +65,5 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return err
}
return func() error {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[c.Backend]
if !ok {
m := &sync.Mutex{}
mutexes[c.Backend] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
return fn()
}, nil
return fn, nil
}


@ -15,7 +15,17 @@ import (
"github.com/go-skynet/LocalAI/pkg/utils"
)
func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}
type TokenUsage struct {
Prompt int
Completion int
}
func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
@ -70,40 +80,56 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (string, error) {
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
tokenUsage := TokenUsage{}
// check the per-model feature flag for usage, since tokenCallback may have a cost, but default to on.
if !c.FeatureFlag["usage"] {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
err := inferenceModel.PredictStream(ctx, opts, func(s []byte) {
tokenCallback(string(s))
tokenCallback(string(s), tokenUsage)
ss += string(s)
})
return ss, err
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return "", err
return LLMResponse{}, err
}
return string(reply.Message), err
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}
return func() (string, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
return fn()
}, nil
return fn, nil
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
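The callback wrapping above is easy to misread in diff form; here is a compact, standalone restatement of the same pattern (function name is illustrative): keep the caller's callback if there is one, otherwise substitute a no-op, then count one completion token per streamed token before delegating.

func wrapUsage(user func(string, TokenUsage) bool, usage *TokenUsage) func(string, TokenUsage) bool {
	if user == nil {
		user = func(string, TokenUsage) bool { return true } // no-op: keep streaming
	}
	return func(token string, _ TokenUsage) bool {
		usage.Completion++      // one callback per generated token
		return user(token, *usage) // caller sees the running totals
	}
}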


@ -1,22 +0,0 @@
package backend
import "sync"
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
func Lock(s string) *sync.Mutex {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[s]
if !ok {
m := &sync.Mutex{}
mutexes[s] = m
l = m
}
mutexMap.Unlock()
l.Lock()
return l
}
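With this file deleted, the API-side serialization (the mutexMap keyed by model or backend, kept for the llama.cpp issue linked above) moves into each gRPC backend process as Base.backendBusy; see pkg/grpc/base further down.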


@ -29,6 +29,7 @@ type Config struct {
FunctionsConfig Functions `yaml:"function"`
FeatureFlag map[string]bool `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline"`
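Note the polarity in the ModelInference hunk above: token usage is counted when feature_flags.usage is false or unset (the "default to on" the comment mentions), so as written a model config would set usage: true only to opt out of the tokenize-and-count path.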


@ -0,0 +1,142 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
client := bm.options.Loader.CheckIsLoaded(backendId)
if client == nil {
return fmt.Errorf("backend %s is not currently loaded", input.Model)
}
status, rpcErr := client.Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", input.Model, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", input.Model, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
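A quick way to poke the new endpoint once a model is loaded (host, port, and model name are assumptions; note the route is registered as a GET in api.go above even though it reads a JSON body):

// Hypothetical smoke test for /backend/monitor; all concrete values are illustrative.
req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/backend/monitor",
	strings.NewReader(`{"model":"ggml-gpt4all-j"}`))
req.Header.Set("Content-Type", "application/json")
res, err := http.DefaultClient.Do(req) // JSON reply: backend status, or a gopsutil-derived fallback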


@ -29,11 +29,16 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
@ -237,11 +242,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &OpenAIUsage{}
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
@ -261,6 +268,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Delta: &Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
@ -271,7 +279,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return nil
}
result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
@ -327,8 +335,8 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return
}
prediction = backend.Finetune(*config, predInput, prediction)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, Choice{
@ -349,6 +357,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)


@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -18,7 +19,7 @@ import (
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{
@ -28,6 +29,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
},
},
Object: "text_completion",
Usage: OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
@ -120,6 +126,9 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
}
var result []Choice
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
@ -131,13 +140,16 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
@ -145,6 +157,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)


@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -32,6 +33,8 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
var result []Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
@ -44,13 +47,16 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
@ -58,6 +64,11 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)


@ -7,8 +7,8 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
n := req.N
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []Choice{}
if n == 0 {
@ -18,20 +18,25 @@ func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config,
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
if err != nil {
return result, err
return result, backend.TokenUsage{}, err
}
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, err
return result, backend.TokenUsage{}, err
}
prediction = backend.Finetune(*config, predInput, prediction)
cb(prediction, &result)
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, err
return result, tokenUsage, err
}
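Since ComputeChoices now calls predFunc n times and sums each prediction.Usage into tokenUsage, the usage it returns covers all n requested choices, not just the last one; the completion and edit endpoints above additionally sum across prompt/input strings into totalTokenUsage.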

go.mod

@ -36,6 +36,17 @@ require (
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shirou/gopsutil/v3 v3.23.6
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/tklauser/go-sysconf v0.3.11 // indirect
github.com/tklauser/numcpus v0.6.0 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
)
require (
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
@ -50,7 +61,6 @@ require (
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
github.com/ulikunitz/xz v0.5.9 // indirect
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/fsnotify.v1 v1.4.7 // indirect

go.sum

@ -13,8 +13,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674 h1:G70Yf/QOCEL1v24idWnGd6rJsbqiGkJAJnMaWaolzEg=
github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df h1:qVcBEZlvp5A1gGWNJj02xyDtbsUI2hohlQMSB1fgER4=
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY=
@ -33,24 +31,14 @@ github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ=
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA=
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y=
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc=
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0=
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76 h1:NRdxo2MKi8qhWZXxu6CIZOkdH+LBERFz1kk22U1FD3k=
github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119 h1:FeUSk5yMHT7J7jeCQKAOs4x5LRNSYH0SR6djM/i1jcc=
github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230727163958-6ba16de8e965 h1:2MO/rABKpkXnnKQ3Ar90aqhnlMEejE9gnKG6bafv+ow=
github.com/go-skynet/go-llama.cpp v0.0.0-20230727163958-6ba16de8e965/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7 h1:1uBwholTaJ8Lva8ySJjT4jNaCDAh+MJXtsbZBbQq9lA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066 h1:v4Js+yEdgY9IV7n35M+5MELLxlOMp3qC5whZm5YTLjI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230814195654-18f25c21abf9 h1:62wpzDHwjZGfIfimvve3bNrS6/gOLkSfwsCjcSD6g8U=
github.com/go-skynet/go-llama.cpp v0.0.0-20230814195654-18f25c21abf9/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-skynet/go-llama.cpp v0.0.0-20230815201253-f03869d188b7 h1:d/FXe1a55gCLf124uRYYtlYg6KvI7OI33xaFejQUAws=
github.com/go-skynet/go-llama.cpp v0.0.0-20230815201253-f03869d188b7/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
@ -76,6 +64,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -107,6 +96,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
@ -130,22 +121,6 @@ github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9v
github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5 h1:bmQnxyKiqCu8i2y/N/Sf0coWoG2/Ed12YGQeb7lTnjo=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9 h1:/oRwZhulKTU8LpPD2fXi2o2kdlTutQjYWDVMkrv14po=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816 h1:hRi7hpDUuaO0dB4NZ8eyaeD2fRar6CPyNAARsO5DhzA=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586 h1:WVEMSZMyHFe68PN204c3Fdk5g2lZouPvbU9/2zkPpWc=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de h1:E5EGczxEAcbaO8yqj074MQxU609QbtB6in3qTOW1EFo=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee h1:Y/j+GNytyncmDnAEuDZwzkYC9nzUPvXJPF+nntQG0VU=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a h1:bX26Zfwh72ug2aZTEwFISTMEJ56Wa/4KqboidD+g92A=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230814164545-4e55940edf11 h1:72DoTIAcKXEv5Q5MSaHFCpVAQHqwU84wUsxy/UcdKTc=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230814164545-4e55940edf11/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5 h1:b4EeYDaGxOLNlNm5LOVEmrUhaw1v6xq/V79ZwWVlY6I=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
@ -162,8 +137,6 @@ github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ=
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg=
@ -178,39 +151,37 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c=
github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.14.0 h1:D1yAB+DHElgbJFdYyjxfTWMFzhddn+PwZmkQ039L7mQ=
github.com/sashabaranov/go-openai v1.14.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA=
github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.14.2 h1:5DPTtR9JBjKPJS008/A409I5ntFhUPPGCmaAihcPRyo=
github.com/sashabaranov/go-openai v1.14.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/shirou/gopsutil/v3 v3.23.6 h1:5y46WPI9QBKBbK7EEccUPNXpJpNrvPuTD0O2zHEHT08=
github.com/shirou/gopsutil/v3 v3.23.6/go.mod h1:j7QX50DrXYggrpN30W0Mo+I4/8U2UUIQrnrhqUeWrAU=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8 h1:wdJigYmmIRCuXhCkADDr53Oa1fp/WlxCPoVXR2r7GrU=
github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8/go.mod h1:mTzgQfAGwmBz2hhQELZfu2bwsbHwyKHA6IHOa+9LDFg=
github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a h1:I/2JSuYXkWaVVLSZmrPfrgbvvvPR0IaulZcB0Iu8oVI=
github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe h1:+XVrCjh3rPibfISkUFG2Ck5NLKODQ9cFdmraFye1bGA=
github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600 h1:SABuIthjhIXEsxnokuA16CZOxxdW9XohIHQqd/go8Nc=
github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5 h1:js7vYDJGzUGVSt0YlIusUc5BXYVECu3LUI/asby5Ggo=
github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537 h1:vkeNjlW+0Xiw2XizMHoQuLG8pg6AN1hU8zJuMV9GQBc=
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4=
github.com/tmc/langchaingo v0.0.0-20230815194031-eb0cbd31327d h1:RBu2wOoyzxNxYTitUKVNDtU1H6T4Tu5skOwvZabnPFc=
github.com/tmc/langchaingo v0.0.0-20230815194031-eb0cbd31327d/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
@ -229,6 +200,8 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMx
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
@ -240,8 +213,6 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -251,26 +222,28 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -282,14 +255,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 h1:9NWlQfY2ePejTmfwUH1OWwmznFa+0kKcHGPDvcPza9M=
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=

main.go

@ -135,6 +135,12 @@ func main() {
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
EnvVars: []string{"API_KEY"},
},
&cli.BoolFlag{
Name: "preload-backend-only",
Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups.",
EnvVars: []string{"PRELOAD_BACKEND_ONLY"},
Value: false,
},
},
Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
@ -187,6 +193,11 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
opts = append(opts, options.EnableGalleriesAutoload)
}
if ctx.Bool("preload-backend-only") {
_, _, err := api.Startup(opts...)
return err
}
app, err := api.App(opts...)
if err != nil {
return err
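In practice a worker node would start as local-ai --preload-backend-only (or with PRELOAD_BACKEND_ONLY=true): Startup runs the config-loading and model-preload path and returns, and the fiber app is never constructed.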


@ -4,17 +4,39 @@ package base
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"os"
"sync"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type Base struct {
backendBusy sync.Mutex
State pb.StatusResponse_State
}
func (llm *Base) Busy() bool {
r := llm.backendBusy.TryLock()
if r {
llm.backendBusy.Unlock()
}
return r
}
func (llm *Base) Lock() {
llm.backendBusy.Lock()
llm.State = pb.StatusResponse_BUSY
}
func (llm *Base) Unlock() {
llm.State = pb.StatusResponse_READY
llm.backendBusy.Unlock()
}
func (llm *Base) Load(opts *pb.ModelOptions) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) {
@ -40,3 +62,32 @@ func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) {
func (llm *Base) TTS(*pb.TTSRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
}
// backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details?
func (llm *Base) Status() (pb.StatusResponse, error) {
mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64),
}
pid := int32(os.Getpid())
backendProcess, err := gopsutil.NewProcess(pid)
if err == nil {
memInfo, err := backendProcess.MemoryInfo()
if err == nil {
mud.Total = memInfo.VMS // TEST, but rss seems reasonable first guess. Does include swap, but we might care about that.
mud.Breakdown["gopsutil-RSS"] = memInfo.RSS
}
}
return pb.StatusResponse{
State: llm.State,
Memory: &mud,
}, nil
}
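One pattern worth calling out before the per-backend diffs below: Predict and friends use Lock with a deferred Unlock, but PredictStream takes the lock synchronously and releases it inside the goroutine, after close(results), so the backend reports BUSY for the whole stream rather than only for the setup call. A condensed, backend-neutral sketch:

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	llm.Base.Lock() // State -> BUSY; held for the full stream, not just setup
	go func() {
		// ... generate tokens and send each one into results ...
		close(results)
		llm.Base.Unlock() // State -> READY only after the stream is closed
	}()
	return nil
}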


@ -158,3 +158,29 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
tresult.Text = res.Text
return tresult, err
}
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
res, err := client.TokenizeString(ctx, in, opts...)
if err != nil {
return nil, err
}
return res, nil
}
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.Status(ctx, &pb.HealthMessage{})
}
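Usage from the API side mirrors the monitor endpoint above; a minimal sketch (variable names are illustrative):

// Sketch: polling a loaded backend's status via the client wrapper above.
status, err := client.Status(context.Background())
if err == nil {
	log.Debug().Msgf("state=%s rss=%d", status.State.String(), status.Memory.Breakdown["gopsutil-RSS"])
}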


@ -6,6 +6,7 @@ import (
)
type LLM interface {
Busy() bool
Predict(*pb.PredictOptions) (string, error)
PredictStream(*pb.PredictOptions, chan string) error
Load(*pb.ModelOptions) error
@ -13,6 +14,8 @@ type LLM interface {
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (api.Result, error)
TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error)
}
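Both new interface methods have default implementations on Base (TokenizeString returns "unimplemented", Status reports gopsutil memory for the current process), so the per-backend diffs below only need to override where the underlying library can do better, as the llama backend's TokenizeString does.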
func newReply(s string) *pb.Reply {


@ -4,6 +4,7 @@ package bert
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
bert "github.com/go-skynet/go-bert.cpp"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -15,12 +16,21 @@ type Embeddings struct {
}
func (llm *Embeddings) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("bert backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := bert.New(opts.ModelFile)
llm.bert = model
return err
}
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
if len(opts.EmbeddingTokens) > 0 {
tokens := []int{}
for _, t := range opts.EmbeddingTokens {


@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
"github.com/go-skynet/bloomz.cpp"
)
@ -18,6 +19,12 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("bloomz backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := bloomz.New(opts.ModelFile)
llm.bloomz = model
return err
@ -40,11 +47,16 @@ func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -53,6 +65,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil


@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
ggllm "github.com/mudler/go-ggllm.cpp"
)
@ -18,6 +19,13 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
ggllmOpts := []ggllm.ModelOption{}
if opts.ContextSize != 0 {
ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize)))
@ -118,10 +126,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
@ -138,6 +150,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
fmt.Println("err: ", err)
}
close(results)
llm.Base.Unlock()
}()
return nil


@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -17,6 +18,13 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gpt4all backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := gpt4all.New(opts.ModelFile,
gpt4all.SetThreads(int(opts.Threads)),
gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath))
@ -39,10 +47,15 @@ func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
go func() {
@ -56,6 +69,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
llm.gpt4all.SetTokenCallback(nil)
close(results)
llm.Base.Unlock()
}()
return nil

View File

@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -18,12 +19,21 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("langchain backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
var err error
llm.langchain, err = langchain.NewHuggingFace(opts.Model)
llm.model = opts.Model
return err
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
@ -38,6 +48,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
@ -52,6 +63,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res.Completion
close(results)
llm.Base.Unlock()
}()
return nil

View File

@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/go-llama.cpp"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -18,6 +19,13 @@ type LLM struct {
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("llama backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
ropeFreqBase := float32(10000)
ropeFreqScale := float32(1)
@ -73,6 +81,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
model, err := llama.New(opts.ModelFile, llamaOpts...)
llm.llama = model
return err
}
@ -167,10 +176,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
@ -184,12 +197,16 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
fmt.Println("err: ", err)
}
close(results)
llm.Base.Unlock()
}()
return nil
}
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
predictOptions := buildPredictOptions(opts)
if len(opts.EmbeddingTokens) > 0 {
@ -202,3 +219,18 @@ func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
}
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
predictOptions := buildPredictOptions(opts)
l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...)
if err != nil {
return pb.TokenizationResponse{}, err
}
return pb.TokenizationResponse{
Length: l,
Tokens: tokens,
}, nil
}
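
llama is the only backend in this diff that implements the new TokenizeString method; the others presumably fall back to a default on the embedded base type. A plausible shape for such a default follows; this is assumed, not shown in this commit:

package base

import (
    "fmt"

    pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

type Base struct{} // the real Base also carries the mutex and state

// TokenizeString is a hypothetical default for backends without a
// tokenizer; llama overrides it with the real go-llama.cpp call above.
func (b *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
    return pb.TokenizationResponse{}, fmt.Errorf("backend does not implement TokenizeString")
}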

View File

@ -9,6 +9,7 @@ import (
"github.com/donomii/go-rwkv.cpp"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
)
const tokenizerSuffix = ".tokenizer.json"
@ -20,6 +21,12 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("rwkv backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
modelPath := filepath.Dir(opts.ModelFile)
modelFile := filepath.Base(opts.ModelFile)
model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
@ -32,6 +39,8 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
stopWord := "\n"
if len(opts.StopPrompts) > 0 {
@ -48,6 +57,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
stopWord := "\n"
@ -65,6 +75,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
return true
})
close(results)
llm.Base.Unlock()
}()
return nil
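
The rwkv backend derives its stop word from the request: newline by default, otherwise the first entry of StopPrompts. A small sketch of that convention; the cut-at-stop-word step is illustrative, since the real binding receives the stop word directly:

package main

import (
    "fmt"
    "strings"
)

// applyStopWord mirrors the selection logic above: default to newline,
// prefer the caller's first stop prompt, then truncate the output at
// the first occurrence.
func applyStopWord(out string, stopPrompts []string) string {
    stopWord := "\n"
    if len(stopPrompts) > 0 {
        stopWord = stopPrompts[0]
    }
    if i := strings.Index(out, stopWord); i >= 0 {
        return out[:i]
    }
    return out
}

func main() {
    fmt.Println(applyStopWord("first line\nsecond line", nil)) // prints "first line"
}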

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,27 @@ type Dolly struct {
}
func (llm *Dolly) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("dolly backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewDolly(opts.ModelFile)
llm.dolly = model
return err
}
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +48,7 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) er
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Falcon struct {
}
func (llm *Falcon) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("transformers-falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewFalcon(opts.ModelFile)
llm.falcon = model
return err
}
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) e
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPT2 struct {
}
func (llm *GPT2) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gpt2 backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.New(opts.ModelFile)
llm.gpt2 = model
return err
}
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) err
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPTJ struct {
}
func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gptj backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewGPTJ(opts.ModelFile)
llm.gptj = model
return err
}
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) err
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPTNeoX struct {
}
func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gptneox backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewGPTNeoX(opts.ModelFile)
llm.gptneox = model
return err
}
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string)
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,27 @@ type MPT struct {
}
func (llm *MPT) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("mpt backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewMPT(opts.ModelFile)
llm.mpt = model
return err
}
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +48,7 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Replit struct {
}
func (llm *Replit) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("replit backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewReplit(opts.ModelFile)
llm.replit = model
return err
}
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) e
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Starcoder struct {
}
func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("starcoder backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewStarcoder(opts.ModelFile)
llm.starcoder = model
return err
}
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
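
All eight go-ggml-transformers backends (dolly through starcoder) repeat the same "fallback to Predict" streaming body. Here is a hypothetical refactor, not part of this commit, that factors the shape into one helper:

package main

import "fmt"

// predictor is the minimal surface the transformers backends share.
type predictor interface {
    Predict(prompt string) (string, error)
}

type toy struct{}

func (toy) Predict(prompt string) (string, error) { return "out: " + prompt, nil }

// fallbackStream runs one blocking Predict, emits the whole result as a
// single chunk, and invokes unlock only after the channel is closed:
// the exact sequence each backend writes out by hand.
func fallbackStream(p predictor, unlock func(), prompt string, results chan string) error {
    go func() {
        defer unlock()
        if res, err := p.Predict(prompt); err == nil {
            results <- res
        }
        close(results)
    }()
    return nil
}

func main() {
    results := make(chan string)
    _ = fallbackStream(toy{}, func() {}, "hi", results)
    for chunk := range results {
        fmt.Println(chunk)
    }
}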

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.15.8
// protoc-gen-go v1.27.1
// protoc v3.12.4
// source: pkg/grpc/proto/backend.proto
package proto
@ -20,6 +20,58 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type StatusResponse_State int32
const (
StatusResponse_UNINITIALIZED StatusResponse_State = 0
StatusResponse_BUSY StatusResponse_State = 1
StatusResponse_READY StatusResponse_State = 2
StatusResponse_ERROR StatusResponse_State = -1
)
// Enum value maps for StatusResponse_State.
var (
StatusResponse_State_name = map[int32]string{
0: "UNINITIALIZED",
1: "BUSY",
2: "READY",
-1: "ERROR",
}
StatusResponse_State_value = map[string]int32{
"UNINITIALIZED": 0,
"BUSY": 1,
"READY": 2,
"ERROR": -1,
}
)
func (x StatusResponse_State) Enum() *StatusResponse_State {
p := new(StatusResponse_State)
*p = x
return p
}
func (x StatusResponse_State) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (StatusResponse_State) Descriptor() protoreflect.EnumDescriptor {
return file_pkg_grpc_proto_backend_proto_enumTypes[0].Descriptor()
}
func (StatusResponse_State) Type() protoreflect.EnumType {
return &file_pkg_grpc_proto_backend_proto_enumTypes[0]
}
func (x StatusResponse_State) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use StatusResponse_State.Descriptor instead.
func (StatusResponse_State) EnumDescriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13, 0}
}
type HealthMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -1253,6 +1305,171 @@ func (x *TTSRequest) GetDst() string {
return ""
}
type TokenizationResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Length int32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"`
Tokens []int32 `protobuf:"varint,2,rep,packed,name=tokens,proto3" json:"tokens,omitempty"`
}
func (x *TokenizationResponse) Reset() {
*x = TokenizationResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TokenizationResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TokenizationResponse) ProtoMessage() {}
func (x *TokenizationResponse) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TokenizationResponse.ProtoReflect.Descriptor instead.
func (*TokenizationResponse) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{11}
}
func (x *TokenizationResponse) GetLength() int32 {
if x != nil {
return x.Length
}
return 0
}
func (x *TokenizationResponse) GetTokens() []int32 {
if x != nil {
return x.Tokens
}
return nil
}
type MemoryUsageData struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"`
Breakdown map[string]uint64 `protobuf:"bytes,2,rep,name=breakdown,proto3" json:"breakdown,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"`
}
func (x *MemoryUsageData) Reset() {
*x = MemoryUsageData{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MemoryUsageData) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MemoryUsageData) ProtoMessage() {}
func (x *MemoryUsageData) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MemoryUsageData.ProtoReflect.Descriptor instead.
func (*MemoryUsageData) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{12}
}
func (x *MemoryUsageData) GetTotal() uint64 {
if x != nil {
return x.Total
}
return 0
}
func (x *MemoryUsageData) GetBreakdown() map[string]uint64 {
if x != nil {
return x.Breakdown
}
return nil
}
type StatusResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
State StatusResponse_State `protobuf:"varint,1,opt,name=state,proto3,enum=backend.StatusResponse_State" json:"state,omitempty"`
Memory *MemoryUsageData `protobuf:"bytes,2,opt,name=memory,proto3" json:"memory,omitempty"`
}
func (x *StatusResponse) Reset() {
*x = StatusResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StatusResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StatusResponse) ProtoMessage() {}
func (x *StatusResponse) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead.
func (*StatusResponse) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13}
}
func (x *StatusResponse) GetState() StatusResponse_State {
if x != nil {
return x.State
}
return StatusResponse_UNINITIALIZED
}
func (x *StatusResponse) GetMemory() *MemoryUsageData {
if x != nil {
return x.Memory
}
return nil
}
var File_pkg_grpc_proto_backend_proto protoreflect.FileDescriptor
var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{
@ -1451,44 +1668,80 @@ var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{
0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78,
0x74, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x32, 0xeb, 0x03, 0x0a, 0x07, 0x42, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12,
0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65,
0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50,
0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x22, 0x46, 0x0a, 0x14, 0x54, 0x6f, 0x6b,
0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28,
0x05, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, 0x6b,
0x65, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x05, 0x52, 0x06, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x73, 0x22, 0xac, 0x01, 0x0a, 0x0f, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67,
0x65, 0x44, 0x61, 0x74, 0x61, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x01,
0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x45, 0x0a, 0x09, 0x62,
0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x27,
0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55,
0x73, 0x61, 0x67, 0x65, 0x44, 0x61, 0x74, 0x61, 0x2e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f,
0x77, 0x6e, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x62, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f,
0x77, 0x6e, 0x1a, 0x3c, 0x0a, 0x0e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x45,
0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18,
0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01,
0x22, 0xbc, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x12, 0x33, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01,
0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x12, 0x30, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f,
0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65,
0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67, 0x65, 0x44, 0x61,
0x74, 0x61, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x22, 0x43, 0x0a, 0x05, 0x53, 0x74,
0x61, 0x74, 0x65, 0x12, 0x11, 0x0a, 0x0d, 0x55, 0x4e, 0x49, 0x4e, 0x49, 0x54, 0x49, 0x41, 0x4c,
0x49, 0x5a, 0x45, 0x44, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x42, 0x55, 0x53, 0x59, 0x10, 0x01,
0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x41, 0x44, 0x59, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x05, 0x45,
0x52, 0x52, 0x4f, 0x52, 0x10, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x32,
0xf4, 0x04, 0x0a, 0x07, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48,
0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12,
0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62,
0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63,
0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65,
0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64,
0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64,
0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d,
0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f,
0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64,
0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d,
0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73,
0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e,
0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64,
0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61,
0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64,
0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64,
0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63,
0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12,
0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x54, 0x53, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52,
0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b,
0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63,
0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64,
0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d,
0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65,
0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62,
0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12,
0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69,
0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e,
0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d,
0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x54, 0x54, 0x53, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4a, 0x0a,
0x0e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x65, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12,
0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63,
0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65,
0x6e, 0x64, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x62, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79,
0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61,
0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -1503,43 +1756,56 @@ func file_pkg_grpc_proto_backend_proto_rawDescGZIP() []byte {
return file_pkg_grpc_proto_backend_proto_rawDescData
}
var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 11)
var file_pkg_grpc_proto_backend_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 15)
var file_pkg_grpc_proto_backend_proto_goTypes = []interface{}{
(*HealthMessage)(nil), // 0: backend.HealthMessage
(*PredictOptions)(nil), // 1: backend.PredictOptions
(*Reply)(nil), // 2: backend.Reply
(*ModelOptions)(nil), // 3: backend.ModelOptions
(*Result)(nil), // 4: backend.Result
(*EmbeddingResult)(nil), // 5: backend.EmbeddingResult
(*TranscriptRequest)(nil), // 6: backend.TranscriptRequest
(*TranscriptResult)(nil), // 7: backend.TranscriptResult
(*TranscriptSegment)(nil), // 8: backend.TranscriptSegment
(*GenerateImageRequest)(nil), // 9: backend.GenerateImageRequest
(*TTSRequest)(nil), // 10: backend.TTSRequest
(StatusResponse_State)(0), // 0: backend.StatusResponse.State
(*HealthMessage)(nil), // 1: backend.HealthMessage
(*PredictOptions)(nil), // 2: backend.PredictOptions
(*Reply)(nil), // 3: backend.Reply
(*ModelOptions)(nil), // 4: backend.ModelOptions
(*Result)(nil), // 5: backend.Result
(*EmbeddingResult)(nil), // 6: backend.EmbeddingResult
(*TranscriptRequest)(nil), // 7: backend.TranscriptRequest
(*TranscriptResult)(nil), // 8: backend.TranscriptResult
(*TranscriptSegment)(nil), // 9: backend.TranscriptSegment
(*GenerateImageRequest)(nil), // 10: backend.GenerateImageRequest
(*TTSRequest)(nil), // 11: backend.TTSRequest
(*TokenizationResponse)(nil), // 12: backend.TokenizationResponse
(*MemoryUsageData)(nil), // 13: backend.MemoryUsageData
(*StatusResponse)(nil), // 14: backend.StatusResponse
nil, // 15: backend.MemoryUsageData.BreakdownEntry
}
var file_pkg_grpc_proto_backend_proto_depIdxs = []int32{
8, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment
0, // 1: backend.Backend.Health:input_type -> backend.HealthMessage
1, // 2: backend.Backend.Predict:input_type -> backend.PredictOptions
3, // 3: backend.Backend.LoadModel:input_type -> backend.ModelOptions
1, // 4: backend.Backend.PredictStream:input_type -> backend.PredictOptions
1, // 5: backend.Backend.Embedding:input_type -> backend.PredictOptions
9, // 6: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest
6, // 7: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest
10, // 8: backend.Backend.TTS:input_type -> backend.TTSRequest
2, // 9: backend.Backend.Health:output_type -> backend.Reply
2, // 10: backend.Backend.Predict:output_type -> backend.Reply
4, // 11: backend.Backend.LoadModel:output_type -> backend.Result
2, // 12: backend.Backend.PredictStream:output_type -> backend.Reply
5, // 13: backend.Backend.Embedding:output_type -> backend.EmbeddingResult
4, // 14: backend.Backend.GenerateImage:output_type -> backend.Result
7, // 15: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult
4, // 16: backend.Backend.TTS:output_type -> backend.Result
9, // [9:17] is the sub-list for method output_type
1, // [1:9] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
9, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment
15, // 1: backend.MemoryUsageData.breakdown:type_name -> backend.MemoryUsageData.BreakdownEntry
0, // 2: backend.StatusResponse.state:type_name -> backend.StatusResponse.State
13, // 3: backend.StatusResponse.memory:type_name -> backend.MemoryUsageData
1, // 4: backend.Backend.Health:input_type -> backend.HealthMessage
2, // 5: backend.Backend.Predict:input_type -> backend.PredictOptions
4, // 6: backend.Backend.LoadModel:input_type -> backend.ModelOptions
2, // 7: backend.Backend.PredictStream:input_type -> backend.PredictOptions
2, // 8: backend.Backend.Embedding:input_type -> backend.PredictOptions
10, // 9: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest
7, // 10: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest
11, // 11: backend.Backend.TTS:input_type -> backend.TTSRequest
2, // 12: backend.Backend.TokenizeString:input_type -> backend.PredictOptions
1, // 13: backend.Backend.Status:input_type -> backend.HealthMessage
3, // 14: backend.Backend.Health:output_type -> backend.Reply
3, // 15: backend.Backend.Predict:output_type -> backend.Reply
5, // 16: backend.Backend.LoadModel:output_type -> backend.Result
3, // 17: backend.Backend.PredictStream:output_type -> backend.Reply
6, // 18: backend.Backend.Embedding:output_type -> backend.EmbeddingResult
5, // 19: backend.Backend.GenerateImage:output_type -> backend.Result
8, // 20: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult
5, // 21: backend.Backend.TTS:output_type -> backend.Result
12, // 22: backend.Backend.TokenizeString:output_type -> backend.TokenizationResponse
14, // 23: backend.Backend.Status:output_type -> backend.StatusResponse
14, // [14:24] is the sub-list for method output_type
4, // [4:14] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() { file_pkg_grpc_proto_backend_proto_init() }
@ -1680,19 +1946,56 @@ func file_pkg_grpc_proto_backend_proto_init() {
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TokenizationResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MemoryUsageData); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StatusResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_proto_backend_proto_rawDesc,
NumEnums: 0,
NumMessages: 11,
NumEnums: 1,
NumMessages: 15,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_proto_backend_proto_goTypes,
DependencyIndexes: file_pkg_grpc_proto_backend_proto_depIdxs,
EnumInfos: file_pkg_grpc_proto_backend_proto_enumTypes,
MessageInfos: file_pkg_grpc_proto_backend_proto_msgTypes,
}.Build()
File_pkg_grpc_proto_backend_proto = out.File

View File

@ -16,6 +16,8 @@ service Backend {
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
rpc Status(HealthMessage) returns (StatusResponse) {}
}
message HealthMessage {}
@ -157,3 +159,24 @@ message TTSRequest {
string model = 2;
string dst = 3;
}
message TokenizationResponse {
int32 length = 1;
repeated int32 tokens = 2;
}
message MemoryUsageData {
uint64 total = 1;
map<string, uint64> breakdown = 2;
}
message StatusResponse {
enum State {
UNINITIALIZED = 0;
BUSY = 1;
READY = 2;
ERROR = -1;
}
State state = 1;
MemoryUsageData memory = 2;
}
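
With the service definition in hand, a client can exercise the two new RPCs directly. A minimal sketch follows; the address is a placeholder, since LocalAI assigns each backend its own socket:

package main

import (
    "context"
    "fmt"
    "log"

    pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

func main() {
    // 127.0.0.1:50051 is illustrative; use the backend's actual address.
    conn, err := grpc.Dial("127.0.0.1:50051",
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    client := pb.NewBackendClient(conn)

    st, err := client.Status(context.Background(), &pb.HealthMessage{})
    if err != nil {
        log.Fatal(err)
    }
    // The generated getters are nil-safe, so the memory breakdown may be absent.
    fmt.Println("state:", st.GetState(), "total bytes:", st.GetMemory().GetTotal())

    tok, err := client.TokenizeString(context.Background(),
        &pb.PredictOptions{Prompt: "hello world"})
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%d tokens: %v\n", tok.GetLength(), tok.GetTokens())
}

One wire-format note on the enum: proto3 allows ERROR = -1, but negative enum values are encoded as sign-extended varints (ten bytes each on the wire), which is why style guides usually reserve negatives for sentinels like this error state.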

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.15.8
// - protoc-gen-go-grpc v1.3.0
// - protoc v3.12.4
// source: pkg/grpc/proto/backend.proto
package proto
@ -18,6 +18,19 @@ import (
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
)
// BackendClient is the client API for Backend service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -30,6 +43,8 @@ type BackendClient interface {
GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error)
AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error)
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error)
TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error)
Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error)
}
type backendClient struct {
@ -42,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -51,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -60,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -68,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
}
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
if err != nil {
return nil, err
}
@ -101,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -110,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -119,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -128,7 +143,25 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse)
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -147,6 +180,8 @@ type BackendServer interface {
GenerateImage(context.Context, *GenerateImageRequest) (*Result, error)
AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error)
TTS(context.Context, *TTSRequest) (*Result, error)
TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error)
Status(context.Context, *HealthMessage) (*StatusResponse, error)
mustEmbedUnimplementedBackendServer()
}
@ -178,6 +213,12 @@ func (UnimplementedBackendServer) AudioTranscription(context.Context, *Transcrip
func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented")
}
func (UnimplementedBackendServer) TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TokenizeString not implemented")
}
func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Status not implemented")
}
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {}
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
@ -201,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Health",
FullMethod: Backend_Health_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -219,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Predict",
FullMethod: Backend_Predict_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -237,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/LoadModel",
FullMethod: Backend_LoadModel_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -276,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Embedding",
FullMethod: Backend_Embedding_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -294,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/GenerateImage",
FullMethod: Backend_GenerateImage_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -312,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/AudioTranscription",
FullMethod: Backend_AudioTranscription_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -330,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/TTS",
FullMethod: Backend_TTS_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -338,6 +379,42 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
return interceptor(ctx, in, info, handler)
}
func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).TokenizeString(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_TokenizeString_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).Status(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Status_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
}
return interceptor(ctx, in, info, handler)
}
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -373,6 +450,14 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
MethodName: "TTS",
Handler: _Backend_TTS_Handler,
},
{
MethodName: "TokenizeString",
Handler: _Backend_TokenizeString_Handler,
},
{
MethodName: "Status",
Handler: _Backend_Status_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@ -110,6 +110,32 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS
return nil
}
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
res, err := s.llm.TokenizeString(in)
if err != nil {
return nil, err
}
castTokens := make([]int32, len(res.Tokens))
for i, v := range res.Tokens {
castTokens[i] = int32(v)
}
return &pb.TokenizationResponse{
Length: int32(res.Length),
Tokens: castTokens,
}, err
}
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) {
res, err := s.llm.Status()
if err != nil {
return nil, err
}
return &res, nil
}
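
The server side simply forwards to the LLM interface, so each backend decides what Status reports. Below is a hypothetical backend implementation; the enum and field names come from the generated proto package, and the numbers are invented:

package backend

import (
    pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// LLM stands in for any concrete backend type.
type LLM struct{}

// Status reports READY plus a coarse memory breakdown; a real backend
// would substitute measured values.
func (llm *LLM) Status() (pb.StatusResponse, error) {
    return pb.StatusResponse{
        State: pb.StatusResponse_READY,
        Memory: &pb.MemoryUsageData{
            Total:     512 << 20, // bytes
            Breakdown: map[string]uint64{"model": 480 << 20, "kv-cache": 32 << 20},
        },
    }, nil
}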
func StartServer(address string, model LLM) error {
lis, err := net.Listen("tcp", address)
if err != nil {

View File

@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
@ -64,10 +65,33 @@ var AutoLoadBackends []string = []string{
PiperBackend,
}
func (ml *ModelLoader) StopGRPC() {
for _, p := range ml.grpcProcesses {
p.Stop()
func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
p, exists := ml.grpcProcesses[id]
if !exists {
return -1, fmt.Errorf("no grpc backend found for %s", id)
}
return strconv.Atoi(p.PID)
}
type GRPCProcessFilter = func(p *process.Process) bool
func includeAllProcesses(_ *process.Process) bool {
return true
}
func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
for _, p := range ml.grpcProcesses {
if filter(p) {
p.Stop()
}
}
}
func (ml *ModelLoader) StopAllGRPC() {
ml.StopGRPC(includeAllProcesses)
}
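
The filter hook makes selective shutdown possible. Here is a hypothetical helper built on the signatures above; Process.PID is a string in this API (see GetGRPCPID), and the process import is assumed to be the mudler/go-processmanager package LocalAI uses:

package model

import (
    process "github.com/mudler/go-processmanager"
)

// stopAllExcept stops every managed gRPC backend except the one whose
// PID matches; illustrative only, not part of this commit.
func stopAllExcept(ml *ModelLoader, pid string) {
    ml.StopGRPC(func(p *process.Process) bool {
        return p.PID != pid
    })
}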
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
@ -252,7 +276,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
// Is this really needed? BackendLoader already does this
ml.mu.Lock()
if m := ml.checkIsLoaded(o.model); m != nil {
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
ml.mu.Unlock()
return m, nil

View File

@ -103,7 +103,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
defer ml.mu.Unlock()
// Check if we already have a loaded model
if model := ml.checkIsLoaded(modelName); model != nil {
if model := ml.CheckIsLoaded(modelName); model != nil {
return model, nil
}
@ -128,7 +128,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
return model, nil
}
func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client {
if m, ok := ml.models[s]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", s)