From 5dcfdbe51da5b8c9159a358ab1694c0e4f68f437 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH] feat: various refactorings Signed-off-by: Ettore Di Giacinto --- api/api.go | 108 ++-- api/api_test.go | 12 +- api/backend/embeddings.go | 107 ++++ api/backend/image.go | 56 ++ api/backend/llm.go | 160 ++++++ api/backend/lock.go | 22 + api/backend/options.go | 98 ++++ api/config.go | 401 ------------- api/config/config.go | 209 +++++++ api/{ => config}/config_test.go | 24 +- api/config/prediction.go | 37 ++ api/{ => localai}/gallery.go | 21 +- api/{ => localai}/localai.go | 21 +- api/openai.go | 973 -------------------------------- api/openai/api.go | 105 ++++ api/openai/chat.go | 320 +++++++++++ api/openai/completion.go | 159 ++++++ api/openai/edit.go | 67 +++ api/openai/embeddings.go | 70 +++ api/openai/image.go | 158 ++++++ api/openai/inference.go | 36 ++ api/openai/list.go | 37 ++ api/openai/request.go | 234 ++++++++ api/openai/transcription.go | 91 +++ api/{ => options}/options.go | 84 +-- api/prediction.go | 415 -------------- main.go | 35 +- pkg/grpc/llm/falcon/falcon.go | 3 + 28 files changed, 2130 insertions(+), 1933 deletions(-) create mode 100644 api/backend/embeddings.go create mode 100644 api/backend/image.go create mode 100644 api/backend/llm.go create mode 100644 api/backend/lock.go create mode 100644 api/backend/options.go delete mode 100644 api/config.go create mode 100644 api/config/config.go rename api/{ => config}/config_test.go (62%) create mode 100644 api/config/prediction.go rename api/{ => localai}/gallery.go (86%) rename api/{ => localai}/localai.go (68%) delete mode 100644 api/openai.go create mode 100644 api/openai/api.go create mode 100644 api/openai/chat.go create mode 100644 api/openai/completion.go create mode 100644 api/openai/edit.go create mode 100644 api/openai/embeddings.go create mode 100644 api/openai/image.go create mode 100644 api/openai/inference.go create mode 100644 api/openai/list.go create mode 100644 api/openai/request.go create mode 100644 api/openai/transcription.go rename api/{ => options}/options.go (60%) delete mode 100644 api/prediction.go diff --git a/api/api.go b/api/api.go index 1438f1f0..5d4f4c97 100644 --- a/api/api.go +++ b/api/api.go @@ -3,8 +3,13 @@ package api import ( "errors" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/localai" + "github.com/go-skynet/LocalAI/api/openai" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" @@ -13,18 +18,18 @@ import ( "github.com/rs/zerolog/log" ) -func App(opts ...AppOption) (*fiber.App, error) { - options := newOptions(opts...) +func App(opts ...options.AppOption) (*fiber.App, error) { + options := options.NewOptions(opts...) zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.debug { + if options.Debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.disableMessage, + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -38,44 +43,44 @@ func App(opts ...AppOption) (*fiber.App, error) { // Send custom error page return ctx.Status(code).JSON( - ErrorResponse{ - Error: &APIError{Message: err.Error(), Code: code}, + openai.ErrorResponse{ + Error: &openai.APIError{Message: err.Error(), Code: code}, }, ) }, }) - if options.debug { + if options.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) } - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath) + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - cm := NewConfigMerger() - if err := cm.LoadConfigs(options.loader.ModelPath); err != nil { + cm := config.NewConfigLoader() + if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } - if options.configFile != "" { - if err := cm.LoadConfigFile(options.configFile); err != nil { + if options.ConfigFile != "" { + if err := cm.LoadConfigFile(options.ConfigFile); err != nil { log.Error().Msgf("error loading config file: %s", err.Error()) } } - if options.debug { + if options.Debug { for _, v := range cm.ListConfigs() { cfg, _ := cm.GetConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } - if options.assetsDestination != "" { + if options.AssetsDestination != "" { // Extract files from the embedded FS - err := assets.ExtractFiles(options.backendAssets, options.assetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.assetsDestination) + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) if err != nil { log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) } @@ -84,31 +89,32 @@ func App(opts ...AppOption) (*fiber.App, error) { // Default middleware config app.Use(recover.New()) - if options.preloadJSONModels != "" { - if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil { + if options.PreloadJSONModels != "" { + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil { return nil, err } } - if options.preloadModelsFromPath != "" { - if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil { + if options.PreloadModelsFromPath != "" { + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil { return nil, err } } - if options.cors { - if options.corsAllowOrigins == "" { - app.Use(cors.New()) + if options.CORS { + var c func(ctx *fiber.Ctx) error + if options.CORSAllowOrigins == "" { + c = cors.New() } else { - app.Use(cors.New(cors.Config{ - AllowOrigins: options.corsAllowOrigins, - })) + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) } + + app.Use(c) } // LocalAI API endpoints - applier := newGalleryApplier(options.loader.ModelPath) - applier.start(options.context, cm) + galleryService := localai.NewGalleryService(options.Loader.ModelPath) + galleryService.Start(options.Context, cm) app.Get("/version", func(c *fiber.Ctx) error { return c.JSON(struct { @@ -116,43 +122,43 @@ func App(opts ...AppOption) (*fiber.App, error) { }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) - app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath)) - app.Get("/models/jobs/:uuid", getOpStatus(applier)) + app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) + app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) + app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", chatEndpoint(cm, options)) - app.Post("/chat/completions", chatEndpoint(cm, options)) + app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options)) + app.Post("/chat/completions", openai.ChatEndpoint(cm, options)) // edit - app.Post("/v1/edits", editEndpoint(cm, options)) - app.Post("/edits", editEndpoint(cm, options)) + app.Post("/v1/edits", openai.EditEndpoint(cm, options)) + app.Post("/edits", openai.EditEndpoint(cm, options)) // completion - app.Post("/v1/completions", completionEndpoint(cm, options)) - app.Post("/completions", completionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options)) + app.Post("/v1/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options)) - app.Post("/tts", ttsEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options)) + app.Post("/tts", localai.TTSEndpoint(cm, options)) // images - app.Post("/v1/images/generations", imageEndpoint(cm, options)) + app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options)) - if options.imageDir != "" { - app.Static("/generated-images", options.imageDir) + if options.ImageDir != "" { + app.Static("/generated-images", options.ImageDir) } - if options.audioDir != "" { - app.Static("/generated-audio", options.audioDir) + if options.AudioDir != "" { + app.Static("/generated-audio", options.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -164,8 +170,8 @@ func App(opts ...AppOption) (*fiber.App, error) { app.Get("/readyz", ok) // models - app.Get("/v1/models", listModels(options.loader, cm)) - app.Get("/models", listModels(options.loader, cm)) + app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) + app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) return app, nil } diff --git a/api/api_test.go b/api/api_test.go index 43aa30bb..a69e60d2 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "runtime" . "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -154,9 +155,10 @@ var _ = Describe("API test", func() { }, } - app, err = App(WithContext(c), - WithGalleries(galleries), - WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) + app, err = App( + options.WithContext(c), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -342,7 +344,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader)) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -462,7 +464,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE"))) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go new file mode 100644 index 00000000..cb77b6f5 --- /dev/null +++ b/api/backend/embeddings.go @@ -0,0 +1,107 @@ +package backend + +import ( + "context" + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + model "github.com/go-skynet/LocalAI/pkg/model" + bert "github.com/go-skynet/go-bert.cpp" +) + +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { + if !c.Embeddings { + return nil, fmt.Errorf("endpoint disabled for this model by API configuration") + } + + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel interface{} + var err error + + opts := []model.Option{ + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + var fn func() ([]float32, error) + switch model := inferenceModel.(type) { + case *grpc.Client: + fn = func() ([]float32, error) { + predictOptions := gRPCPredictOpts(c, loader.ModelPath) + if len(tokens) > 0 { + embeds := []int32{} + + for _, t := range tokens { + embeds = append(embeds, int32(t)) + } + predictOptions.EmbeddingTokens = embeds + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + predictOptions.Embeddings = s + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + + // bert embeddings + case *bert.Bert: + fn = func() ([]float32, error) { + if len(tokens) > 0 { + return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) + } + return model.Embeddings(s, bert.SetThreads(c.Threads)) + } + default: + fn = func() ([]float32, error) { + return nil, fmt.Errorf("embeddings not supported by the backend") + } + } + + return func() ([]float32, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + l := Lock(modelFile) + defer l.Unlock() + + embeds, err := fn() + if err != nil { + return embeds, err + } + // Remove trailing 0s + for i := len(embeds) - 1; i >= 0; i-- { + if embeds[i] == 0.0 { + embeds = embeds[:i] + } else { + break + } + } + return embeds, nil + }, nil +} diff --git a/api/backend/image.go b/api/backend/image.go new file mode 100644 index 00000000..47ae8428 --- /dev/null +++ b/api/backend/image.go @@ -0,0 +1,56 @@ +package backend + +import ( + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" +) + +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { + if c.Backend != model.StableDiffusionBackend { + return nil, fmt.Errorf("endpoint only working with stablediffusion models") + } + + inferenceModel, err := loader.BackendLoader( + model.WithBackendString(c.Backend), + model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithModelFile(c.ImageGenerationAssets), + ) + if err != nil { + return nil, err + } + + var fn func() error + switch model := inferenceModel.(type) { + case *stablediffusion.StableDiffusion: + fn = func() error { + return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) + } + + default: + fn = func() error { + return fmt.Errorf("creation of images not supported by the backend") + } + } + + return func() error { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[c.Backend] + if !ok { + m := &sync.Mutex{} + mutexes[c.Backend] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} diff --git a/api/backend/llm.go b/api/backend/llm.go new file mode 100644 index 00000000..d2f8ef65 --- /dev/null +++ b/api/backend/llm.go @@ -0,0 +1,160 @@ +package backend + +import ( + "context" + "regexp" + "strings" + "sync" + + "github.com/donomii/go-rwkv.cpp" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/langchain" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/bloomz.cpp" +) + +func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { + supportStreams := false + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel interface{} + var err error + + opts := []model.Option{ + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // GPT4all uses this + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + var fn func() (string, error) + + switch model := inferenceModel.(type) { + case *rwkv.RwkvState: + supportStreams = true + + fn = func() (string, error) { + stopWord := "\n" + if len(c.StopWords) > 0 { + stopWord = c.StopWords[0] + } + + if err := model.ProcessInput(s); err != nil { + return "", err + } + + response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) + + return response, nil + } + case *bloomz.Bloomz: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(c.Temperature), + bloomz.SetTopP(c.TopP), + bloomz.SetTopK(c.TopK), + bloomz.SetTokens(c.Maxtokens), + bloomz.SetThreads(c.Threads), + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } + + case *grpc.Client: + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + supportStreams = true + fn = func() (string, error) { + + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + if tokenCallback != nil { + ss := "" + err := model.PredictStream(context.TODO(), opts, func(s string) { + tokenCallback(s) + ss += s + }) + return ss, err + } else { + reply, err := model.Predict(context.TODO(), opts) + return reply.Message, err + } + } + case *langchain.HuggingFace: + fn = func() (string, error) { + + // Generate the prediction using the language model + predictOptions := []langchain.PredictOption{ + langchain.SetModel(c.Model), + langchain.SetMaxTokens(c.Maxtokens), + langchain.SetTemperature(c.Temperature), + langchain.SetStopWords(c.StopWords), + } + + pred, er := model.PredictHuggingFace(s, predictOptions...) + if er != nil { + return "", er + } + return pred.Completion, nil + } + } + + return func() (string, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + l := Lock(modelFile) + defer l.Unlock() + + res, err := fn() + if tokenCallback != nil && !supportStreams { + tokenCallback(res) + } + return res, err + }, nil +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config config.Config, input, prediction string) string { + if config.Echo { + prediction = input + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range config.TrimSpace { + prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) + } + return prediction + +} diff --git a/api/backend/lock.go b/api/backend/lock.go new file mode 100644 index 00000000..6b4f577c --- /dev/null +++ b/api/backend/lock.go @@ -0,0 +1,22 @@ +package backend + +import "sync" + +// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 +var mutexMap sync.Mutex +var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) + +func Lock(s string) *sync.Mutex { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[s] + if !ok { + m := &sync.Mutex{} + mutexes[s] = m + l = m + } + mutexMap.Unlock() + l.Lock() + + return l +} diff --git a/api/backend/options.go b/api/backend/options.go new file mode 100644 index 00000000..f19dbaeb --- /dev/null +++ b/api/backend/options.go @@ -0,0 +1,98 @@ +package backend + +import ( + "os" + "path/filepath" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/langchain" + "github.com/go-skynet/bloomz.cpp" +) + +func langchainOptions(c config.Config) []langchain.PredictOption { + return []langchain.PredictOption{ + langchain.SetModel(c.Model), + langchain.SetMaxTokens(c.Maxtokens), + langchain.SetTemperature(c.Temperature), + langchain.SetStopWords(c.StopWords), + } +} + +func bloomzOptions(c config.Config) []bloomz.PredictOption { + // Generate the prediction using the language model + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(c.Temperature), + bloomz.SetTopP(c.TopP), + bloomz.SetTopK(c.TopK), + bloomz.SetTokens(c.Maxtokens), + bloomz.SetThreads(c.Threads), + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) + } + return predictOptions +} +func gRPCModelOpts(c config.Config) *pb.ModelOptions { + b := 512 + if c.Batch != 0 { + b = c.Batch + } + return &pb.ModelOptions{ + ContextSize: int32(c.ContextSize), + Seed: int32(c.Seed), + NBatch: int32(b), + F16Memory: c.F16, + MLock: c.MMlock, + NUMA: c.NUMA, + Embeddings: c.Embeddings, + LowVRAM: c.LowVRAM, + NGPULayers: int32(c.NGPULayers), + MMap: c.MMap, + MainGPU: c.MainGPU, + Threads: int32(c.Threads), + TensorSplit: c.TensorSplit, + } +} + +func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { + promptCachePath := "" + if c.PromptCachePath != "" { + p := filepath.Join(modelPath, c.PromptCachePath) + os.MkdirAll(filepath.Dir(p), 0755) + promptCachePath = p + } + return &pb.PredictOptions{ + Temperature: float32(c.Temperature), + TopP: float32(c.TopP), + TopK: int32(c.TopK), + Tokens: int32(c.Maxtokens), + Threads: int32(c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + F16KV: c.F16, + DebugMode: c.Debug, + Grammar: c.Grammar, + + Mirostat: int32(c.Mirostat), + MirostatETA: float32(c.MirostatETA), + MirostatTAU: float32(c.MirostatTAU), + Debug: c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: int32(c.Seed), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: c.MMlock, + MMap: c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(c.TFZ), + TypicalP: float32(c.TypicalP), + } +} diff --git a/api/config.go b/api/config.go deleted file mode 100644 index 57fe0d10..00000000 --- a/api/config.go +++ /dev/null @@ -1,401 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v3" -) - -type Config struct { - OpenAIRequest `yaml:"parameters"` - Name string `yaml:"name"` - StopWords []string `yaml:"stopwords"` - Cutstrings []string `yaml:"cutstrings"` - TrimSpace []string `yaml:"trimspace"` - ContextSize int `yaml:"context_size"` - F16 bool `yaml:"f16"` - NUMA bool `yaml:"numa"` - Threads int `yaml:"threads"` - Debug bool `yaml:"debug"` - Roles map[string]string `yaml:"roles"` - Embeddings bool `yaml:"embeddings"` - Backend string `yaml:"backend"` - TemplateConfig TemplateConfig `yaml:"template"` - MirostatETA float64 `yaml:"mirostat_eta"` - MirostatTAU float64 `yaml:"mirostat_tau"` - Mirostat int `yaml:"mirostat"` - NGPULayers int `yaml:"gpu_layers"` - MMap bool `yaml:"mmap"` - MMlock bool `yaml:"mmlock"` - LowVRAM bool `yaml:"low_vram"` - - TensorSplit string `yaml:"tensor_split"` - MainGPU string `yaml:"main_gpu"` - ImageGenerationAssets string `yaml:"asset_dir"` - - PromptCachePath string `yaml:"prompt_cache_path"` - PromptCacheAll bool `yaml:"prompt_cache_all"` - PromptCacheRO bool `yaml:"prompt_cache_ro"` - - Grammar string `yaml:"grammar"` - - FunctionsConfig Functions `yaml:"function"` - - PromptStrings, InputStrings []string - InputToken [][]int - functionCallString, functionCallNameString string -} - -type Functions struct { - DisableNoAction bool `yaml:"disable_no_action"` - NoActionFunctionName string `yaml:"no_action_function_name"` - NoActionDescriptionName string `yaml:"no_action_description_name"` -} - -type TemplateConfig struct { - Completion string `yaml:"completion"` - Functions string `yaml:"function"` - Chat string `yaml:"chat"` - Edit string `yaml:"edit"` -} - -type ConfigMerger struct { - configs map[string]Config - sync.Mutex -} - -func defaultConfig(modelFile string) *Config { - return &Config{ - OpenAIRequest: defaultRequest(modelFile), - } -} - -func NewConfigMerger() *ConfigMerger { - return &ConfigMerger{ - configs: make(map[string]Config), - } -} -func ReadConfigFile(file string) ([]*Config, error) { - c := &[]*Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return *c, nil -} - -func ReadConfig(file string) (*Config, error) { - c := &Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return c, nil -} - -func (cm *ConfigMerger) LoadConfigFile(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfigFile(file) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cm.configs[cc.Name] = *cc - } - return nil -} - -func (cm *ConfigMerger) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfig(file) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cm.configs[c.Name] = *c - return nil -} - -func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] - return v, exists -} - -func (cm *ConfigMerger) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() - var res []string - for k := range cm.configs { - res = append(res, k) - } - return res -} - -func (cm *ConfigMerger) LoadConfigs(path string) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") { - continue - } - c, err := ReadConfig(filepath.Join(path, file.Name())) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} - -func updateConfig(config *Config, input *OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != 0 { - config.Temperature = input.Temperature - } - - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.F16 { - config.F16 = input.F16 - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != 0 { - config.Seed = input.Seed - } - - if input.Mirostat != 0 { - config.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.MirostatTAU = input.MirostatTAU - } - - if input.TypicalP != 0 { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.functionCallString = fnc - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.functionCallNameString = name - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} -func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { - input := new(OpenAIRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, err - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists && randomModel { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - return modelFile, input, nil -} - -func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { - // Load a config file if present after the model name - modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") - - var config *Config - - defaults := func() { - config = defaultConfig(modelFile) - config.ContextSize = ctx - config.Threads = threads - config.F16 = f16 - config.Debug = debug - } - - cfg, exists := cm.GetConfig(modelFile) - if !exists { - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfg, exists = cm.GetConfig(modelFile) - if exists { - config = &cfg - } else { - defaults() - } - } else { - defaults() - } - } else { - config = &cfg - } - - // Set the parameters for the language model prediction - updateConfig(config, input) - - // Don't allow 0 as setting - if config.Threads == 0 { - if threads != 0 { - config.Threads = threads - } else { - config.Threads = 4 - } - } - - // Enforce debug flag if passed from CLI - if debug { - config.Debug = true - } - - return config, input, nil -} diff --git a/api/config/config.go b/api/config/config.go new file mode 100644 index 00000000..9df8d3e0 --- /dev/null +++ b/api/config/config.go @@ -0,0 +1,209 @@ +package api_config + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "gopkg.in/yaml.v3" +) + +type Config struct { + PredictionOptions `yaml:"parameters"` + Name string `yaml:"name"` + StopWords []string `yaml:"stopwords"` + Cutstrings []string `yaml:"cutstrings"` + TrimSpace []string `yaml:"trimspace"` + ContextSize int `yaml:"context_size"` + F16 bool `yaml:"f16"` + NUMA bool `yaml:"numa"` + Threads int `yaml:"threads"` + Debug bool `yaml:"debug"` + Roles map[string]string `yaml:"roles"` + Embeddings bool `yaml:"embeddings"` + Backend string `yaml:"backend"` + TemplateConfig TemplateConfig `yaml:"template"` + MirostatETA float64 `yaml:"mirostat_eta"` + MirostatTAU float64 `yaml:"mirostat_tau"` + Mirostat int `yaml:"mirostat"` + NGPULayers int `yaml:"gpu_layers"` + MMap bool `yaml:"mmap"` + MMlock bool `yaml:"mmlock"` + LowVRAM bool `yaml:"low_vram"` + + TensorSplit string `yaml:"tensor_split"` + MainGPU string `yaml:"main_gpu"` + ImageGenerationAssets string `yaml:"asset_dir"` + + PromptCachePath string `yaml:"prompt_cache_path"` + PromptCacheAll bool `yaml:"prompt_cache_all"` + PromptCacheRO bool `yaml:"prompt_cache_ro"` + + Grammar string `yaml:"grammar"` + + PromptStrings, InputStrings []string + InputToken [][]int + functionCallString, functionCallNameString string + + FunctionsConfig Functions `yaml:"function"` +} + +type Functions struct { + DisableNoAction bool `yaml:"disable_no_action"` + NoActionFunctionName string `yaml:"no_action_function_name"` + NoActionDescriptionName string `yaml:"no_action_description_name"` +} + +type TemplateConfig struct { + Completion string `yaml:"completion"` + Functions string `yaml:"function"` + Chat string `yaml:"chat"` + Edit string `yaml:"edit"` +} + +type ConfigLoader struct { + configs map[string]Config + sync.Mutex +} + +func (c *Config) SetFunctionCallString(s string) { + c.functionCallString = s +} + +func (c *Config) SetFunctionCallNameString(s string) { + c.functionCallNameString = s +} + +func (c *Config) ShouldUseFunctions() bool { + return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) +} + +func (c *Config) ShouldCallSpecificFunction() bool { + return len(c.functionCallNameString) > 0 +} + +func (c *Config) FunctionToCall() string { + return c.functionCallNameString +} + +func defaultPredictOptions(modelFile string) PredictionOptions { + return PredictionOptions{ + TopP: 0.7, + TopK: 80, + Maxtokens: 512, + Temperature: 0.9, + Model: modelFile, + } +} + +func DefaultConfig(modelFile string) *Config { + return &Config{ + PredictionOptions: defaultPredictOptions(modelFile), + } +} + +func NewConfigLoader() *ConfigLoader { + return &ConfigLoader{ + configs: make(map[string]Config), + } +} +func ReadConfigFile(file string) ([]*Config, error) { + c := &[]*Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return *c, nil +} + +func ReadConfig(file string) (*Config, error) { + c := &Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return c, nil +} + +func (cm *ConfigLoader) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cm *ConfigLoader) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfig(file) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cm.configs[c.Name] = *c + return nil +} + +func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm *ConfigLoader) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + +func (cm *ConfigLoader) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") { + continue + } + c, err := ReadConfig(filepath.Join(path, file.Name())) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/api/config_test.go b/api/config/config_test.go similarity index 62% rename from api/config_test.go rename to api/config/config_test.go index 626b90be..4b00d587 100644 --- a/api/config_test.go +++ b/api/config/config_test.go @@ -1,8 +1,10 @@ -package api +package api_config_test import ( "os" + . "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -26,29 +28,29 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewConfigMerger() - options := newOptions() + cm := NewConfigLoader() + opts := options.NewOptions() modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) - WithModelLoader(modelLoader)(options) + options.WithModelLoader(modelLoader)(opts) - err := cm.LoadConfigs(options.loader.ModelPath) + err := cm.LoadConfigs(opts.Loader.ModelPath) Expect(err).To(BeNil()) - Expect(cm.configs).ToNot(BeNil()) + Expect(cm.ListConfigs()).ToNot(BeNil()) // config should includes gpt4all models's api.config - Expect(cm.configs).To(HaveKey("gpt4all")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) // config should includes gpt2 models's api.config - Expect(cm.configs).To(HaveKey("gpt4all-2")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) // config should includes text-embedding-ada-002 models's api.config - Expect(cm.configs).To(HaveKey("text-embedding-ada-002")) + Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config - Expect(cm.configs).To(HaveKey("rwkv_test")) + Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config - Expect(cm.configs).To(HaveKey("whisper-1")) + Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) }) }) }) diff --git a/api/config/prediction.go b/api/config/prediction.go new file mode 100644 index 00000000..59f4fcb1 --- /dev/null +++ b/api/config/prediction.go @@ -0,0 +1,37 @@ +package api_config + +type PredictionOptions struct { + + // Also part of the OpenAI official spec + Model string `json:"model" yaml:"model"` + + // Also part of the OpenAI official spec + Language string `json:"language"` + + // Also part of the OpenAI official spec. use it for returning multiple results + N int `json:"n"` + + // Common options between all the API calls, part of the OpenAI spec + TopP float64 `json:"top_p" yaml:"top_p"` + TopK int `json:"top_k" yaml:"top_k"` + Temperature float64 `json:"temperature" yaml:"temperature"` + Maxtokens int `json:"max_tokens" yaml:"max_tokens"` + Echo bool `json:"echo"` + + // Custom parameters - not present in the OpenAI API + Batch int `json:"batch" yaml:"batch"` + F16 bool `json:"f16" yaml:"f16"` + IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` + RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` + Keep int `json:"n_keep" yaml:"n_keep"` + + MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` + MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` + Mirostat int `json:"mirostat" yaml:"mirostat"` + + FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` + TFZ float64 `json:"tfz" yaml:"tfz"` + + TypicalP float64 `json:"typical_p" yaml:"typical_p"` + Seed int `json:"seed" yaml:"seed"` +} diff --git a/api/gallery.go b/api/localai/gallery.go similarity index 86% rename from api/gallery.go rename to api/localai/gallery.go index 1c0cec91..feae2942 100644 --- a/api/gallery.go +++ b/api/localai/gallery.go @@ -1,4 +1,4 @@ -package api +package localai import ( "context" @@ -9,6 +9,7 @@ import ( json "github.com/json-iterator/go" + config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/gofiber/fiber/v2" "github.com/google/uuid" @@ -38,7 +39,7 @@ type galleryApplier struct { statuses map[string]*galleryOpStatus } -func newGalleryApplier(modelPath string) *galleryApplier { +func NewGalleryService(modelPath string) *galleryApplier { return &galleryApplier{ modelPath: modelPath, C: make(chan galleryOp), @@ -47,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } // prepareModel applies a -func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { +func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { config, err := gallery.GetGalleryConfigFromURL(req.URL) if err != nil { @@ -72,7 +73,7 @@ func (g *galleryApplier) getStatus(s string) *galleryOpStatus { return g.statuses[s] } -func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { +func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { go func() { for { select { @@ -148,7 +149,7 @@ type galleryModel struct { ID string `json:"id"` } -func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -156,7 +157,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gal return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) } -func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { @@ -174,7 +175,9 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []g return err } -func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { +/// Endpoints + +func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { status := g.getStatus(c.Params("uuid")) @@ -191,7 +194,7 @@ type GalleryModel struct { gallery.GalleryModel } -func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { +func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(GalleryModel) // Get input data from the request body @@ -216,7 +219,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal } } -func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { +func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", galleries) diff --git a/api/localai.go b/api/localai/localai.go similarity index 68% rename from api/localai.go rename to api/localai/localai.go index 66eda5a6..f79e8896 100644 --- a/api/localai.go +++ b/api/localai/localai.go @@ -1,10 +1,13 @@ -package api +package localai import ( "fmt" "os" "path/filepath" + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/tts" "github.com/go-skynet/LocalAI/pkg/utils" @@ -32,7 +35,7 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { +func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(TTSRequest) @@ -41,10 +44,10 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return err } - piperModel, err := o.loader.BackendLoader( + piperModel, err := o.Loader.BackendLoader( model.WithBackendString(model.PiperBackend), model.WithModelFile(input.Model), - model.WithAssetDir(o.assetsDestination)) + model.WithAssetDir(o.AssetsDestination)) if err != nil { return err } @@ -58,16 +61,16 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return fmt.Errorf("loader returned non-piper object %+v", w) } - if err := os.MkdirAll(o.audioDir, 0755); err != nil { + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { return err } - fileName := generateUniqueFileName(o.audioDir, "piper", ".wav") - filePath := filepath.Join(o.audioDir, fileName) + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) - modelPath := filepath.Join(o.loader.ModelPath, input.Model) + modelPath := filepath.Join(o.Loader.ModelPath, input.Model) - if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil { + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { return err } diff --git a/api/openai.go b/api/openai.go deleted file mode 100644 index c39b1cc6..00000000 --- a/api/openai.go +++ /dev/null @@ -1,973 +0,0 @@ -package api - -import ( - "bufio" - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -// APIError provides error information returned by the OpenAI API. -type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` -} - -type ErrorResponse struct { - Error *APIError `json:"error,omitempty"` -} - -type OpenAIUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Item struct { - Embedding []float32 `json:"embedding"` - Index int `json:"index"` - Object string `json:"object,omitempty"` - - // Images - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` -} - -type OpenAIResponse struct { - Created int `json:"created,omitempty"` - Object string `json:"object,omitempty"` - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Choices []Choice `json:"choices,omitempty"` - Data []Item `json:"data,omitempty"` - - Usage OpenAIUsage `json:"usage"` -} - -type Choice struct { - Index int `json:"index,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Message *Message `json:"message,omitempty"` - Delta *Message `json:"delta,omitempty"` - Text string `json:"text,omitempty"` -} - -type Message struct { - // The message role - Role string `json:"role,omitempty" yaml:"role"` - // The message content - Content *string `json:"content" yaml:"content"` - // A result of a function call - FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` -} - -type OpenAIModel struct { - ID string `json:"id"` - Object string `json:"object"` -} - -type OpenAIRequest struct { - Model string `json:"model" yaml:"model"` - - // whisper - File string `json:"file" validate:"required"` - Language string `json:"language"` - //whisper/image - ResponseFormat string `json:"response_format"` - // image - Size string `json:"size"` - // Prompt is read only by completion/image API calls - Prompt interface{} `json:"prompt" yaml:"prompt"` - - // Edit endpoint - Instruction string `json:"instruction" yaml:"instruction"` - Input interface{} `json:"input" yaml:"input"` - - Stop interface{} `json:"stop" yaml:"stop"` - - // Messages is read only by chat/completion API calls - Messages []Message `json:"messages" yaml:"messages"` - - // A list of available functions to call - Functions []grammar.Function `json:"functions" yaml:"functions"` - FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object - - Stream bool `json:"stream"` - Echo bool `json:"echo"` - // Common options between all the API calls - TopP float64 `json:"top_p" yaml:"top_p"` - TopK int `json:"top_k" yaml:"top_k"` - Temperature float64 `json:"temperature" yaml:"temperature"` - Maxtokens int `json:"max_tokens" yaml:"max_tokens"` - - N int `json:"n"` - - // Custom parameters - not present in the OpenAI API - Batch int `json:"batch" yaml:"batch"` - F16 bool `json:"f16" yaml:"f16"` - IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` - RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` - Keep int `json:"n_keep" yaml:"n_keep"` - - MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` - MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` - Mirostat int `json:"mirostat" yaml:"mirostat"` - - FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` - TFZ float64 `json:"tfz" yaml:"tfz"` - - Seed int `json:"seed" yaml:"seed"` - - // Image (not supported by OpenAI) - Mode int `json:"mode"` - Step int `json:"step"` - - // A grammar to constrain the LLM output - Grammar string `json:"grammar" yaml:"grammar"` - // A grammar object - JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` - - TypicalP float64 `json:"typical_p" yaml:"typical_p"` -} - -func defaultRequest(modelFile string) OpenAIRequest { - return OpenAIRequest{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, - } -} - -// https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("`input`: %+v", input) - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - - if input.Stream { - if len(config.PromptStrings) > 1 { - return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - } - - predInput := config.PromptStrings[0] - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - FinishReason: "stop", - }, - }, - Object: "text_completion", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - var result []Choice - for _, i := range config.PromptStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{ - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "text_completion", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - items := []Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := ModelEmbedding("", s, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func isEOS(s string) bool { - if s == "<|endoftext|>" { - return true - } - - return false -} -func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - initialMessage := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - if s != "" && !isEOS(s) { - responses <- resp - } - return true - }) - close(responses) - } - return func(c *fiber.Ctx) error { - processFunctions := false - funcs := grammar.Functions{} - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if config.FunctionsConfig.NoActionFunctionName != "" { - noActionName = config.FunctionsConfig.NoActionFunctionName - } - if config.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = config.FunctionsConfig.NoActionDescriptionName - } - - // process functions if we have any defined or if we have a function call string - if len(input.Functions) > 0 && - ((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.functionCallNameString != "" { - funcs = funcs.Select(config.functionCallNameString) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") - } - - // functions are not supported in stream mode (yet?) - toStream := input.Stream && !processFunctions - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - mess := []string{} - for _, i := range input.Messages { - var content string - role := i.Role - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if i.FunctionCall != nil && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && *i.Content != "" - if r != "" { - if contentExists { - content = fmt.Sprint(r, " ", *i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - } else { - if contentExists { - content = fmt.Sprint(*i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - if toStream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = config.TemplateConfig.Functions - } - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Functions []grammar.Function - }{ - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - - if toStream { - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - FinishReason: "stop", - Index: 0, - Delta: &Message{}, - }}, - Object: "chat.completion.chunk", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) - ss := map[string]interface{}{} - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) - return - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction = Finetune(*config, predInput, prediction) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) - } else { - // otherwise reply with the function call - *c = append(*c, Choice{ - FinishReason: "function_call", - Message: &Message{Role: "assistant", FunctionCall: ss}, - }) - } - - return - } - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - templateFile := config.Model - - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - - var result []Choice - for _, i := range config.InputStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Instruction string - }{Input: i}) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "edit", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/images/create - -/* -* - - curl http://localhost:8080/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cute baby sea otter", - "n": 1, - "size": "512x512" - }' - -* -*/ -func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if m == "" { - m = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", m) - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - // XXX: Only stablediffusion is supported for now - if config.Backend == "" { - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return fmt.Errorf("Invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat == "b64_json" { - b64JSON = true - } - - var result []Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := 15 - - if input.Mode != 0 { - mode = input.Mode - } - - if input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = o.imageDir - } - // Create a temporary file - outputFile, err := ioutil.TempFile(tempDir, "b64") - if err != nil { - return err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } - - baseURL := c.BaseURL() - - fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) - if err != nil { - return err - } - if err := fn(); err != nil { - return err - } - - item := &Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base - } - - result = append(result, *item) - } - } - - resp := &OpenAIResponse{ - Data: result, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - // retrieve the file data from the request - file, err := c.FormFile("file") - if err != nil { - return err - } - f, err := file.Open() - if err != nil { - return err - } - defer f.Close() - - dir, err := os.MkdirTemp("", "whisper") - - if err != nil { - return err - } - defer os.RemoveAll(dir) - - dst := filepath.Join(dir, path.Base(file.Filename)) - dstFile, err := os.Create(dst) - if err != nil { - return err - } - - if _, err := io.Copy(dstFile, f); err != nil { - log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) - return err - } - - log.Debug().Msgf("Audio file copied to: %+v", dst) - - whisperModel, err := o.loader.BackendLoader( - model.WithBackendString(model.WhisperBackend), - model.WithModelFile(config.Model), - model.WithThreads(uint32(config.Threads)), - model.WithAssetDir(o.assetsDestination)) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - w, ok := whisperModel.(whisper.Model) - if !ok { - return fmt.Errorf("loader returned non-whisper object") - } - - tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) - if err != nil { - return err - } - - log.Debug().Msgf("Trascribed: %+v", tr) - // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) - } -} - -func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - models, err := loader.ListModels() - if err != nil { - return err - } - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []OpenAIModel{} - for _, m := range models { - mm[m] = nil - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) - } - - for _, k := range cm.ListConfigs() { - if _, exists := mm[k]; !exists { - dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) - } - } - - return c.JSON(struct { - Object string `json:"object"` - Data []OpenAIModel `json:"data"` - }{ - Object: "list", - Data: dataModels, - }) - } -} diff --git a/api/openai/api.go b/api/openai/api.go new file mode 100644 index 00000000..6d7ce5ea --- /dev/null +++ b/api/openai/api.go @@ -0,0 +1,105 @@ +package openai + +import ( + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/pkg/grammar" +) + +// APIError provides error information returned by the OpenAI API. +type APIError struct { + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +type ErrorResponse struct { + Error *APIError `json:"error,omitempty"` +} + +type OpenAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Item struct { + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object,omitempty"` + + // Images + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` +} + +type OpenAIResponse struct { + Created int `json:"created,omitempty"` + Object string `json:"object,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Choices []Choice `json:"choices,omitempty"` + Data []Item `json:"data,omitempty"` + + Usage OpenAIUsage `json:"usage"` +} + +type Choice struct { + Index int `json:"index,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message *Message `json:"message,omitempty"` + Delta *Message `json:"delta,omitempty"` + Text string `json:"text,omitempty"` +} + +type Message struct { + // The message role + Role string `json:"role,omitempty" yaml:"role"` + // The message content + Content *string `json:"content" yaml:"content"` + // A result of a function call + FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` +} + +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` +} + +type OpenAIRequest struct { + config.PredictionOptions + + // whisper + File string `json:"file" validate:"required"` + //whisper/image + ResponseFormat string `json:"response_format"` + // image + Size string `json:"size"` + // Prompt is read only by completion/image API calls + Prompt interface{} `json:"prompt" yaml:"prompt"` + + // Edit endpoint + Instruction string `json:"instruction" yaml:"instruction"` + Input interface{} `json:"input" yaml:"input"` + + Stop interface{} `json:"stop" yaml:"stop"` + + // Messages is read only by chat/completion API calls + Messages []Message `json:"messages" yaml:"messages"` + + // A list of available functions to call + Functions []grammar.Function `json:"functions" yaml:"functions"` + FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + + Stream bool `json:"stream"` + + // Image (not supported by OpenAI) + Mode int `json:"mode"` + Step int `json:"step"` + + // A grammar to constrain the LLM output + Grammar string `json:"grammar" yaml:"grammar"` + + JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` +} diff --git a/api/openai/chat.go b/api/openai/chat.go new file mode 100644 index 00000000..30f6e01a --- /dev/null +++ b/api/openai/chat.go @@ -0,0 +1,320 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + initialMessage := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + } + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { + processFunctions := false + funcs := grammar.Functions{} + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && config.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, input.Functions...) + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream && !processFunctions + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + mess := []string{} + for _, i := range input.Messages { + var content string + role := i.Role + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if i.FunctionCall != nil && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && *i.Content != "" + if r != "" { + if contentExists { + content = fmt.Sprint(r, " ", *i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + } else { + if contentExists { + content = fmt.Sprint(*i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + if toStream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + // c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Functions []grammar.Function + }{ + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + if toStream { + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + FinishReason: "stop", + Index: 0, + Delta: &Message{}, + }}, + Object: "chat.completion.chunk", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == noActionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) another computation + config.Grammar = "" + predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction = backend.Finetune(*config, predInput, prediction) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) + } else { + // otherwise reply with the function call + *c = append(*c, Choice{ + FinishReason: "function_call", + Message: &Message{Role: "assistant", FunctionCall: ss}, + }) + } + + return + } + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return err + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/completion.go b/api/openai/completion.go new file mode 100644 index 00000000..d17fd607 --- /dev/null +++ b/api/openai/completion.go @@ -0,0 +1,159 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// https://platform.openai.com/docs/api-reference/completions +func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("`input`: %+v", input) + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + //c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + var result []Choice + for _, i := range config.PromptStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/edit.go b/api/openai/edit.go new file mode 100644 index 00000000..d988d6d1 --- /dev/null +++ b/api/openai/edit.go @@ -0,0 +1,67 @@ +package openai + +import ( + "encoding/json" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := config.Model + + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []Choice + for _, i := range config.InputStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Instruction string + }{Input: i}) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go new file mode 100644 index 00000000..248ae5cf --- /dev/null +++ b/api/openai/embeddings.go @@ -0,0 +1,70 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/embeddings +func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + items := []Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/image.go b/api/openai/image.go new file mode 100644 index 00000000..bca54c16 --- /dev/null +++ b/api/openai/image.go @@ -0,0 +1,158 @@ +package openai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o.Loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + // XXX: Only stablediffusion is supported for now + if config.Backend == "" { + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("Invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat == "b64_json" { + b64JSON = true + } + + var result []Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := 15 + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = o.ImageDir + } + // Create a temporary file + outputFile, err := ioutil.TempFile(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + resp := &OpenAIResponse{ + Data: result, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/inference.go b/api/openai/inference.go new file mode 100644 index 00000000..a9991fa0 --- /dev/null +++ b/api/openai/inference.go @@ -0,0 +1,36 @@ +package openai + +import ( + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { + result := []Choice{} + + if n == 0 { + n = 1 + } + + // get the model function to call for the result + predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) + if err != nil { + return result, err + } + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, err + } + + prediction = backend.Finetune(*config, predInput, prediction) + cb(prediction, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, err +} diff --git a/api/openai/list.go b/api/openai/list.go new file mode 100644 index 00000000..0cd7f3af --- /dev/null +++ b/api/openai/list.go @@ -0,0 +1,37 @@ +package openai + +import ( + config "github.com/go-skynet/LocalAI/api/config" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + models, err := loader.ListModels() + if err != nil { + return err + } + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []OpenAIModel{} + for _, m := range models { + mm[m] = nil + dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) + } + + for _, k := range cm.ListConfigs() { + if _, exists := mm[k]; !exists { + dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) + } + } + + return c.JSON(struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` + }{ + Object: "list", + Data: dataModels, + }) + } +} diff --git a/api/openai/request.go b/api/openai/request.go new file mode 100644 index 00000000..84dbaa8e --- /dev/null +++ b/api/openai/request.go @@ -0,0 +1,234 @@ +package openai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { + input := new(OpenAIRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, err + } + + modelFile := input.Model + + if c.Params("model") != "" { + modelFile = c.Params("model") + } + + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists && randomModel { + models, _ := loader.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return "", nil, fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + return modelFile, input, nil +} + +func updateConfig(config *config.Config, input *OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } + + if input.Mirostat != 0 { + config.Mirostat = input.Mirostat + } + + if input.MirostatETA != 0 { + config.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.MirostatTAU = input.MirostatTAU + } + + if input.TypicalP != 0 { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if !e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { + // Load a config file if present after the model name + modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") + + var cfg *config.Config + + defaults := func() { + cfg = config.DefaultConfig(modelFile) + cfg.ContextSize = ctx + cfg.Threads = threads + cfg.F16 = f16 + cfg.Debug = debug + } + + cfgExisting, exists := cm.GetConfig(modelFile) + if !exists { + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cm.GetConfig(modelFile) + if exists { + cfg = &cfgExisting + } else { + defaults() + } + } else { + defaults() + } + } else { + cfg = &cfgExisting + } + + // Set the parameters for the language model prediction + updateConfig(cfg, input) + + // Don't allow 0 as setting + if cfg.Threads == 0 { + if threads != 0 { + cfg.Threads = threads + } else { + cfg.Threads = 4 + } + } + + // Enforce debug flag if passed from CLI + if debug { + cfg.Debug = true + } + + return cfg, input, nil +} diff --git a/api/openai/transcription.go b/api/openai/transcription.go new file mode 100644 index 00000000..279f320a --- /dev/null +++ b/api/openai/transcription.go @@ -0,0 +1,91 @@ +package openai + +import ( + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/audio/create +func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o.Loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + // retrieve the file data from the request + file, err := c.FormFile("file") + if err != nil { + return err + } + f, err := file.Open() + if err != nil { + return err + } + defer f.Close() + + dir, err := os.MkdirTemp("", "whisper") + + if err != nil { + return err + } + defer os.RemoveAll(dir) + + dst := filepath.Join(dir, path.Base(file.Filename)) + dstFile, err := os.Create(dst) + if err != nil { + return err + } + + if _, err := io.Copy(dstFile, f); err != nil { + log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) + return err + } + + log.Debug().Msgf("Audio file copied to: %+v", dst) + + whisperModel, err := o.Loader.BackendLoader( + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(config.Model), + model.WithThreads(uint32(config.Threads)), + model.WithAssetDir(o.AssetsDestination)) + if err != nil { + return err + } + + if whisperModel == nil { + return fmt.Errorf("could not load whisper model") + } + + w, ok := whisperModel.(whisper.Model) + if !ok { + return fmt.Errorf("loader returned non-whisper object") + } + + tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) + if err != nil { + return err + } + + log.Debug().Msgf("Trascribed: %+v", tr) + // TODO: handle different outputs here + return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) + } +} diff --git a/api/options.go b/api/options/options.go similarity index 60% rename from api/options.go rename to api/options/options.go index 923288ac..06029b04 100644 --- a/api/options.go +++ b/api/options/options.go @@ -1,4 +1,4 @@ -package api +package options import ( "context" @@ -11,35 +11,35 @@ import ( ) type Option struct { - context context.Context - configFile string - loader *model.ModelLoader - uploadLimitMB, threads, ctxSize int - f16 bool - debug, disableMessage bool - imageDir string - audioDir string - cors bool - preloadJSONModels string - preloadModelsFromPath string - corsAllowOrigins string + Context context.Context + ConfigFile string + Loader *model.ModelLoader + UploadLimitMB, Threads, ContextSize int + F16 bool + Debug, DisableMessage bool + ImageDir string + AudioDir string + CORS bool + PreloadJSONModels string + PreloadModelsFromPath string + CORSAllowOrigins string - galleries []gallery.Gallery + Galleries []gallery.Gallery - backendAssets embed.FS - assetsDestination string + BackendAssets embed.FS + AssetsDestination string } type AppOption func(*Option) -func newOptions(o ...AppOption) *Option { +func NewOptions(o ...AppOption) *Option { opt := &Option{ - context: context.Background(), - uploadLimitMB: 15, - threads: 1, - ctxSize: 512, - debug: true, - disableMessage: true, + Context: context.Background(), + UploadLimitMB: 15, + Threads: 1, + ContextSize: 512, + Debug: true, + DisableMessage: true, } for _, oo := range o { oo(opt) @@ -49,25 +49,25 @@ func newOptions(o ...AppOption) *Option { func WithCors(b bool) AppOption { return func(o *Option) { - o.cors = b + o.CORS = b } } func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { - o.corsAllowOrigins = b + o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { return func(o *Option) { - o.assetsDestination = out + o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { return func(o *Option) { - o.backendAssets = f + o.BackendAssets = f } } @@ -81,89 +81,89 @@ func WithStringGalleries(galls string) AppOption { if err := json.Unmarshal([]byte(galls), &galleries); err != nil { log.Error().Msgf("failed loading galleries: %s", err.Error()) } - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithGalleries(galleries []gallery.Gallery) AppOption { return func(o *Option) { - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { return func(o *Option) { - o.context = ctx + o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { return func(o *Option) { - o.preloadModelsFromPath = configFile + o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { return func(o *Option) { - o.preloadJSONModels = configFile + o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { return func(o *Option) { - o.configFile = configFile + o.ConfigFile = configFile } } func WithModelLoader(loader *model.ModelLoader) AppOption { return func(o *Option) { - o.loader = loader + o.Loader = loader } } func WithUploadLimitMB(limit int) AppOption { return func(o *Option) { - o.uploadLimitMB = limit + o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { return func(o *Option) { - o.threads = threads + o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { return func(o *Option) { - o.ctxSize = ctxSize + o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { return func(o *Option) { - o.f16 = f16 + o.F16 = f16 } } func WithDebug(debug bool) AppOption { return func(o *Option) { - o.debug = debug + o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { return func(o *Option) { - o.disableMessage = disableMessage + o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { return func(o *Option) { - o.audioDir = audioDir + o.AudioDir = audioDir } } func WithImageDir(imageDir string) AppOption { return func(o *Option) { - o.imageDir = imageDir + o.ImageDir = imageDir } } diff --git a/api/prediction.go b/api/prediction.go deleted file mode 100644 index 4a9c1c84..00000000 --- a/api/prediction.go +++ /dev/null @@ -1,415 +0,0 @@ -package api - -import ( - "context" - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/donomii/go-rwkv.cpp" - "github.com/go-skynet/LocalAI/pkg/grpc" - pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/langchain" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" - "github.com/go-skynet/bloomz.cpp" - bert "github.com/go-skynet/go-bert.cpp" -) - -// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 -var mutexMap sync.Mutex -var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) - -func gRPCModelOpts(c Config) *pb.ModelOptions { - b := 512 - if c.Batch != 0 { - b = c.Batch - } - return &pb.ModelOptions{ - ContextSize: int32(c.ContextSize), - Seed: int32(c.Seed), - NBatch: int32(b), - F16Memory: c.F16, - MLock: c.MMlock, - NUMA: c.NUMA, - Embeddings: c.Embeddings, - LowVRAM: c.LowVRAM, - NGPULayers: int32(c.NGPULayers), - MMap: c.MMap, - MainGPU: c.MainGPU, - Threads: int32(c.Threads), - TensorSplit: c.TensorSplit, - } -} - -func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { - promptCachePath := "" - if c.PromptCachePath != "" { - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - promptCachePath = p - } - return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), - PromptCacheAll: c.PromptCacheAll, - PromptCacheRO: c.PromptCacheRO, - PromptCachePath: promptCachePath, - F16KV: c.F16, - DebugMode: c.Debug, - Grammar: c.Grammar, - - Mirostat: int32(c.Mirostat), - MirostatETA: float32(c.MirostatETA), - MirostatTAU: float32(c.MirostatTAU), - Debug: c.Debug, - StopPrompts: c.StopWords, - Repeat: int32(c.RepeatPenalty), - NKeep: int32(c.Keep), - Batch: int32(c.Batch), - IgnoreEOS: c.IgnoreEOS, - Seed: int32(c.Seed), - FrequencyPenalty: float32(c.FrequencyPenalty), - MLock: c.MMlock, - MMap: c.MMap, - MainGPU: c.MainGPU, - TensorSplit: c.TensorSplit, - TailFreeSamplingZ: float32(c.TFZ), - TypicalP: float32(c.TypicalP), - } -} - -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { - if c.Backend != model.StableDiffusionBackend { - return nil, fmt.Errorf("endpoint only working with stablediffusion models") - } - - inferenceModel, err := loader.BackendLoader( - model.WithBackendString(c.Backend), - model.WithAssetDir(o.assetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithModelFile(c.ImageGenerationAssets), - ) - if err != nil { - return nil, err - } - - var fn func() error - switch model := inferenceModel.(type) { - case *stablediffusion.StableDiffusion: - fn = func() error { - return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) - } - - default: - fn = func() error { - return fmt.Errorf("creation of images not supported by the backend") - } - } - - return func() error { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[c.Backend] - if !ok { - m := &sync.Mutex{} - mutexes[c.Backend] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - return fn() - }, nil -} - -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { - if !c.Embeddings { - return nil, fmt.Errorf("endpoint disabled for this model by API configuration") - } - - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel interface{} - var err error - - opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.assetsDestination), - model.WithModelFile(modelFile), - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(c.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } - if err != nil { - return nil, err - } - - var fn func() ([]float32, error) - switch model := inferenceModel.(type) { - case *grpc.Client: - fn = func() ([]float32, error) { - predictOptions := gRPCPredictOpts(c, loader.ModelPath) - if len(tokens) > 0 { - embeds := []int32{} - - for _, t := range tokens { - embeds = append(embeds, int32(t)) - } - predictOptions.EmbeddingTokens = embeds - - res, err := model.Embeddings(context.TODO(), predictOptions) - if err != nil { - return nil, err - } - - return res.Embeddings, nil - } - predictOptions.Embeddings = s - - res, err := model.Embeddings(context.TODO(), predictOptions) - if err != nil { - return nil, err - } - - return res.Embeddings, nil - } - - // bert embeddings - case *bert.Bert: - fn = func() ([]float32, error) { - if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) - } - return model.Embeddings(s, bert.SetThreads(c.Threads)) - } - default: - fn = func() ([]float32, error) { - return nil, fmt.Errorf("embeddings not supported by the backend") - } - } - - return func() ([]float32, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - embeds, err := fn() - if err != nil { - return embeds, err - } - // Remove trailing 0s - for i := len(embeds) - 1; i >= 0; i-- { - if embeds[i] == 0.0 { - embeds = embeds[:i] - } else { - break - } - } - return embeds, nil - }, nil -} - -func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { - supportStreams := false - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel interface{} - var err error - - opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), // GPT4all uses this - model.WithAssetDir(o.assetsDestination), - model.WithModelFile(modelFile), - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(c.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } - if err != nil { - return nil, err - } - - var fn func() (string, error) - - switch model := inferenceModel.(type) { - case *rwkv.RwkvState: - supportStreams = true - - fn = func() (string, error) { - stopWord := "\n" - if len(c.StopWords) > 0 { - stopWord = c.StopWords[0] - } - - if err := model.ProcessInput(s); err != nil { - return "", err - } - - response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) - - return response, nil - } - case *bloomz.Bloomz: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []bloomz.PredictOption{ - bloomz.SetTemperature(c.Temperature), - bloomz.SetTopP(c.TopP), - bloomz.SetTopK(c.TopK), - bloomz.SetTokens(c.Maxtokens), - bloomz.SetThreads(c.Threads), - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - - case *grpc.Client: - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - supportStreams = true - fn = func() (string, error) { - - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - if tokenCallback != nil { - ss := "" - err := model.PredictStream(context.TODO(), opts, func(s string) { - tokenCallback(s) - ss += s - }) - return ss, err - } else { - reply, err := model.Predict(context.TODO(), opts) - return reply.Message, err - } - } - case *langchain.HuggingFace: - fn = func() (string, error) { - - // Generate the prediction using the language model - predictOptions := []langchain.PredictOption{ - langchain.SetModel(c.Model), - langchain.SetMaxTokens(c.Maxtokens), - langchain.SetTemperature(c.Temperature), - langchain.SetStopWords(c.StopWords), - } - - pred, er := model.PredictHuggingFace(s, predictOptions...) - if er != nil { - return "", er - } - return pred.Completion, nil - } - } - - return func() (string, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - res, err := fn() - if tokenCallback != nil && !supportStreams { - tokenCallback(res) - } - return res, err - }, nil -} - -func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { - result := []Choice{} - - n := input.N - - if input.N == 0 { - n = 1 - } - - // get the model function to call for the result - predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) - if err != nil { - return result, err - } - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, err - } - - prediction = Finetune(*config, predInput, prediction) - cb(prediction, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, err -} - -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} - -func Finetune(config Config, input, prediction string) string { - if config.Echo { - prediction = input + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - return prediction - -} diff --git a/main.go b/main.go index fc1dea09..ec38afe5 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "path/filepath" api "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" @@ -129,23 +130,23 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { app, err := api.App( - api.WithConfigFile(ctx.String("config-file")), - api.WithJSONStringPreload(ctx.String("preload-models")), - api.WithYAMLConfigPreload(ctx.String("preload-models-config")), - api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), - api.WithContextSize(ctx.Int("context-size")), - api.WithDebug(ctx.Bool("debug")), - api.WithImageDir(ctx.String("image-path")), - api.WithAudioDir(ctx.String("audio-path")), - api.WithF16(ctx.Bool("f16")), - api.WithStringGalleries(ctx.String("galleries")), - api.WithDisableMessage(false), - api.WithCors(ctx.Bool("cors")), - api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - api.WithThreads(ctx.Int("threads")), - api.WithBackendAssets(backendAssets), - api.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - api.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithConfigFile(ctx.String("config-file")), + options.WithJSONStringPreload(ctx.String("preload-models")), + options.WithYAMLConfigPreload(ctx.String("preload-models-config")), + options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), + options.WithContextSize(ctx.Int("context-size")), + options.WithDebug(ctx.Bool("debug")), + options.WithImageDir(ctx.String("image-path")), + options.WithAudioDir(ctx.String("audio-path")), + options.WithF16(ctx.Bool("f16")), + options.WithStringGalleries(ctx.String("galleries")), + options.WithDisableMessage(false), + options.WithCors(ctx.Bool("cors")), + options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + options.WithThreads(ctx.Int("threads")), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + options.WithUploadLimitMB(ctx.Int("upload-limit"))) if err != nil { return err } diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go index 5d8cf759..0a7a5334 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/grpc/llm/falcon/falcon.go @@ -126,6 +126,9 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { + if token == "<|endoftext|>" { + return true + } results <- token return true }))