From 47743b74abd0f0265a5329dc715bc06b50ac0264 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 17 Apr 2024 23:36:17 +0200 Subject: [PATCH] Revert "Revert #1963 (#2056)" This reverts commit af9e5a2d05d477eedaf1bff08370208d2b4a9d86. --- .github/workflows/test.yml | 15 +- Makefile | 18 +- backend/go/transcribe/transcript.go | 6 +- backend/go/transcribe/whisper.go | 2 +- core/backend/embeddings.go | 90 +- core/backend/image.go | 261 +++++- core/backend/llm.go | 271 ++++-- core/backend/options.go | 84 +- core/backend/transcript.go | 41 +- core/backend/tts.go | 77 +- core/cli/run.go | 8 +- core/cli/transcript.go | 19 +- core/cli/tts.go | 26 +- core/config/backend_config.go | 301 +------ core/config/backend_config_loader.go | 509 +++++++++++ core/config/exports_test.go | 6 + core/http/api.go | 197 ++--- core/http/api_test.go | 98 ++- core/http/ctx/fiber.go | 65 +- core/http/endpoints/elevenlabs/tts.go | 39 +- .../http/endpoints/localai/backend_monitor.go | 4 +- core/http/endpoints/localai/tts.go | 39 +- core/http/endpoints/openai/assistant.go | 2 +- core/http/endpoints/openai/chat.go | 621 ++------------ core/http/endpoints/openai/completion.go | 163 +--- core/http/endpoints/openai/edit.go | 78 +- core/http/endpoints/openai/embeddings.go | 65 +- core/http/endpoints/openai/image.go | 218 +---- core/http/endpoints/openai/inference.go | 55 -- core/http/endpoints/openai/list.go | 52 +- core/http/endpoints/openai/request.go | 285 ------ core/http/endpoints/openai/transcription.go | 28 +- core/schema/{whisper.go => transcription.go} | 2 +- core/services/backend_monitor.go | 30 +- core/services/gallery.go | 116 ++- core/services/list_models.go | 72 ++ .../services}/model_preload_test.go | 5 +- core/services/openai.go | 808 ++++++++++++++++++ core/startup/startup.go | 91 +- core/state.go | 41 + .../llm text/-completions Stream.bru | 25 + pkg/concurrency/concurrency.go | 135 +++ pkg/concurrency/concurrency_test.go | 101 +++ pkg/concurrency/types.go | 6 + pkg/grpc/backend.go | 2 +- pkg/grpc/base/base.go | 4 +- pkg/grpc/client.go | 4 +- pkg/grpc/embed.go | 4 +- pkg/grpc/interface.go | 2 +- pkg/model/initializers.go | 8 +- pkg/startup/model_preload.go | 85 -- pkg/utils/base64.go | 50 ++ 52 files changed, 3052 insertions(+), 2282 deletions(-) create mode 100644 core/config/backend_config_loader.go create mode 100644 core/config/exports_test.go delete mode 100644 core/http/endpoints/openai/inference.go delete mode 100644 core/http/endpoints/openai/request.go rename core/schema/{whisper.go => transcription.go} (90%) create mode 100644 core/services/list_models.go rename {pkg/startup => core/services}/model_preload_test.go (96%) create mode 100644 core/services/openai.go create mode 100644 core/state.go create mode 100644 examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru create mode 100644 pkg/concurrency/concurrency.go create mode 100644 pkg/concurrency/concurrency_test.go create mode 100644 pkg/concurrency/types.go delete mode 100644 pkg/startup/model_preload.go create mode 100644 pkg/utils/base64.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 46c4e065..156294b5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -121,8 +121,9 @@ jobs: PATH="$PATH:/root/go/bin" GO_TAGS="stablediffusion tts" make --jobs 5 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 - timeout-minutes: 5 + uses: mxschmitt/action-tmate@v3.18 + with: + connect-timeout-seconds: 180 tests-aio-container: runs-on: ubuntu-latest @@ -173,8 +174,9 @@ jobs: make run-e2e-aio - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 - timeout-minutes: 5 + uses: mxschmitt/action-tmate@v3.18 + with: + connect-timeout-seconds: 180 tests-apple: runs-on: macOS-14 @@ -207,5 +209,6 @@ jobs: BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make --jobs 4 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 - timeout-minutes: 5 \ No newline at end of file + uses: mxschmitt/action-tmate@v3.18 + with: + connect-timeout-seconds: 180 \ No newline at end of file diff --git a/Makefile b/Makefile index 6715e91e..fdc7aade 100644 --- a/Makefile +++ b/Makefile @@ -301,6 +301,9 @@ clean-tests: rm -rf test-dir rm -rf core/http/backend-assets +halt-backends: ## Used to clean up stray backends sometimes left running when debugging manually + ps | grep 'backend-assets/grpc/' | awk '{print $$1}' | xargs -I {} kill -9 {} + ## Build: build: prepare backend-assets grpcs ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) @@ -365,13 +368,13 @@ run-e2e-image: run-e2e-aio: @echo 'Running e2e AIO tests' - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio test-e2e: @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e teardown-e2e: rm -rf $(TEST_DIR) || true @@ -379,15 +382,15 @@ teardown-e2e: test-gpt4all: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-llama: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-llama-gguf: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-tts: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ @@ -645,7 +648,10 @@ backend-assets/grpc/llama-ggml: sources/go-llama-ggml sources/go-llama-ggml/libb $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama-ggml CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama-ggml LIBRARY_PATH=$(CURDIR)/sources/go-llama-ggml \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-ggml ./backend/go/llm/llama-ggml/ - +# EXPERIMENTAL: +ifeq ($(BUILD_TYPE),metal) + cp $(CURDIR)/sources/go-llama-ggml/llama.cpp/ggml-metal.metal backend-assets/grpc/ +endif backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/ diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index fdfaa974..b38d5b9f 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -21,7 +21,7 @@ func runCommand(command []string) (string, error) { // AudioToWav converts audio to wav for transcribe. // TODO: use https://github.com/mccoyst/ogg? func audioToWav(src, dst string) error { - command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} + command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} out, err := runCommand(command) if err != nil { return fmt.Errorf("error: %w out: %s", err, out) @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { - res := schema.Result{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { + res := schema.TranscriptionResult{} dir, err := os.MkdirTemp("", "whisper") if err != nil { diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index ac93be01..a9a62d24 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 03ff90b9..2c63dedc 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -2,14 +2,100 @@ package backend import ( "fmt" + "time" "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" + "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { +type EmbeddingsBackendService struct { + ml *model.ModelLoader + bcl *config.BackendConfigLoader + appConfig *config.ApplicationConfig +} + +func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService { + return &EmbeddingsBackendService{ + ml: ml, + bcl: bcl, + appConfig: appConfig, + } +} + +func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { + + resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + go func(request *schema.OpenAIRequest) { + if request.Model == "" { + request.Model = model.StableDiffusionBackend + } + + bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + items := []schema.Item{} + + for i, s := range bc.InputToken { + // get the model function to call for the result + embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + embeddings, err := embedFn() + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range bc.InputStrings { + // get the model function to call for the result + embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + embeddings, err := embedFn() + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} + close(resultChannel) + }(request) + return resultChannel +} + +func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { modelFile := backendConfig.Model grpcOpts := gRPCModelOpts(backendConfig) diff --git a/core/backend/image.go b/core/backend/image.go index b0cffb0b..affb3bb3 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -1,18 +1,252 @@ package backend import ( - "github.com/go-skynet/LocalAI/core/config" + "bufio" + "encoding/base64" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + + "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" ) -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { +type ImageGenerationBackendService struct { + ml *model.ModelLoader + bcl *config.BackendConfigLoader + appConfig *config.ApplicationConfig + BaseUrlForGeneratedImages string +} + +func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService { + return &ImageGenerationBackendService{ + ml: ml, + bcl: bcl, + appConfig: appConfig, + } +} + +func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { + resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + go func(request *schema.OpenAIRequest) { + bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + src := "" + if request.File != "" { + + var fileData []byte + // check if input.File is an URL, if so download it and save it + // to a temporary file + if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") { + out, err := downloadFile(request.File) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)} + close(resultChannel) + return + } + defer os.RemoveAll(out) + + fileData, err = os.ReadFile(out) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)} + close(resultChannel) + return + } + + } else { + // base 64 decode the file and write it somewhere + // that we will cleanup + fileData, err = base64.StdEncoding.DecodeString(request.File) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + } + + // Create a temporary file + outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64") + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(fileData) + if err != nil { + outputFile.Close() + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + outputFile.Close() + src = outputFile.Name() + defer os.RemoveAll(src) + } + + log.Debug().Msgf("Parameter Config: %+v", bc) + + switch bc.Backend { + case "stablediffusion": + bc.Backend = model.StableDiffusionBackend + case "tinydream": + bc.Backend = model.TinyDreamBackend + case "": + bc.Backend = model.StableDiffusionBackend + if bc.Model == "" { + bc.Model = "stablediffusion_assets" // TODO: check? + } + } + + sizeParts := strings.Split(request.Size, "x") + if len(sizeParts) != 2 { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} + close(resultChannel) + return + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} + close(resultChannel) + return + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} + close(resultChannel) + return + } + + b64JSON := false + if request.ResponseFormat.Type == "b64_json" { + b64JSON = true + } + // src and clip_skip + var result []schema.Item + for _, i := range bc.PromptStrings { + n := request.N + if request.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := bc.Step + if step == 0 { + step = 15 + } + + if request.Mode != 0 { + mode = request.Mode + } + + if request.Step != 0 { + step = request.Step + } + + tempDir := "" + if !b64JSON { + tempDir = igbs.appConfig.ImageDir + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + if request.Seed == nil { + zVal := 0 // Idiomatic way to do this? Actually needed? + request.Seed = &zVal + } + + fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + if err := fn(); err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + + item := &schema.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} + close(resultChannel) + return + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = igbs.BaseUrlForGeneratedImages + base + } + + result = append(result, *item) + } + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: result, + } + resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} + close(resultChannel) + }(request) + return resultChannel +} + +func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { + threads := backendConfig.Threads if *threads == 0 && appConfig.Threads != 0 { threads = &appConfig.Threads } + gRPCOpts := gRPCModelOpts(backendConfig) + opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(backendConfig.Backend), model.WithAssetDir(appConfig.AssetsDestination), @@ -50,3 +284,24 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat return fn, nil } + +// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change. +func downloadFile(url string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp("", "image") + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} diff --git a/core/backend/llm.go b/core/backend/llm.go index a4d1e5f3..75766d78 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -11,17 +11,22 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" + "github.com/rs/zerolog/log" + "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -type LLMResponse struct { - Response string // should this be []byte? - Usage TokenUsage +type LLMRequest struct { + Id int // TODO Remove if not used. + Text string + Images []string + RawMessages []schema.Message + // TODO: Other Modalities? } type TokenUsage struct { @@ -29,57 +34,94 @@ type TokenUsage struct { Completion int } -func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { - modelFile := c.Model - threads := c.Threads - if *threads == 0 && o.Threads != 0 { - threads = &o.Threads +type LLMResponse struct { + Request *LLMRequest + Response string // should this be []byte? + Usage TokenUsage +} + +// TODO: Does this belong here or in core/services/openai.go? +type LLMResponseBundle struct { + Request *schema.OpenAIRequest + Response []schema.Choice + Usage TokenUsage +} + +type LLMBackendService struct { + bcl *config.BackendConfigLoader + ml *model.ModelLoader + appConfig *config.ApplicationConfig + ftMutex sync.Mutex + cutstrings map[string]*regexp.Regexp +} + +func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService { + return &LLMBackendService{ + bcl: bcl, + ml: ml, + appConfig: appConfig, + ftMutex: sync.Mutex{}, + cutstrings: make(map[string]*regexp.Regexp), } - grpcOpts := gRPCModelOpts(c) +} + +// TODO: Should ctx param be removed and replaced with hardcoded req.Context? +func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) ( + resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) { + + threads := bc.Threads + if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 { + threads = &llmbs.appConfig.Threads + } + + grpcOpts := gRPCModelOpts(bc) var inferenceModel grpc.Backend - var err error - opts := modelOpts(c, o, []model.Option{ + opts := modelOpts(bc, llmbs.appConfig, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(o.AssetsDestination), - model.WithModel(modelFile), - model.WithContext(o.Context), + model.WithAssetDir(llmbs.appConfig.AssetsDestination), + model.WithModel(bc.Model), + model.WithContext(llmbs.appConfig.Context), }) - if c.Backend != "" { - opts = append(opts, model.WithBackendString(c.Backend)) + if bc.Backend != "" { + opts = append(opts, model.WithBackendString(bc.Backend)) } - // Check if the modelFile exists, if it doesn't try to load it from the gallery - if o.AutoloadGalleries { // experimental - if _, err := os.Stat(modelFile); os.IsNotExist(err) { + // Check if bc.Model exists, if it doesn't try to load it from the gallery + if llmbs.appConfig.AutoloadGalleries { // experimental + if _, err := os.Stat(bc.Model); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) if err != nil { - return nil, err + return nil, nil, err } } } - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) + if bc.Backend == "" { + log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model) + inferenceModel, err = llmbs.ml.GreedyLoader(opts...) } else { - inferenceModel, err = loader.BackendLoader(opts...) + inferenceModel, err = llmbs.ml.BackendLoader(opts...) } if err != nil { - return nil, err + log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend") + return } - var protoMessages []*proto.Message - // if we are using the tokenizer template, we need to convert the messages to proto messages - // unless the prompt has already been tokenized (non-chat endpoints + functions) - if c.TemplateConfig.UseTokenizerTemplate && s == "" { - protoMessages = make([]*proto.Message, len(messages), len(messages)) - for i, message := range messages { + grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath) + grpcPredOpts.Prompt = req.Text + grpcPredOpts.Images = req.Images + + if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" { + grpcPredOpts.UseTokenizerTemplate = true + protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages)) + for i, message := range req.RawMessages { protoMessages[i] = &proto.Message{ Role: message.Role, } @@ -87,47 +129,32 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im case string: protoMessages[i].Content = ct default: - return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct) + err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) + return } } } - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - fn := func() (LLMResponse, error) { - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - opts.Messages = protoMessages - opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate - opts.Images = images + tokenUsage := TokenUsage{} - tokenUsage := TokenUsage{} + promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } - // check the per-model feature flag for usage, since tokenCallback may have a cost. - // Defaults to off as for now it is still experimental - if c.FeatureFlag.Enabled("usage") { - userTokenCallback := tokenCallback - if userTokenCallback == nil { - userTokenCallback = func(token string, usage TokenUsage) bool { - return true - } - } + rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse]) + // TODO this next line is the biggest argument for taking named return values _back_ out!!! + var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse] - promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) - if pErr == nil && promptInfo.Length > 0 { - tokenUsage.Prompt = int(promptInfo.Length) - } + if enableTokenChannel { + rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse]) - tokenCallback = func(token string, usage TokenUsage) bool { - tokenUsage.Completion++ - return userTokenCallback(token, tokenUsage) - } - } - - if tokenCallback != nil { - ss := "" + // TODO Needs better name + ss := "" + go func() { var partialRune []byte - err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { + err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) { partialRune = append(partialRune, chars...) for len(partialRune) > 0 { @@ -137,54 +164,126 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im break } - tokenCallback(string(r), tokenUsage) + tokenUsage.Completion++ + rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ + Response: string(r), + Usage: tokenUsage, + }} + ss += string(r) partialRune = partialRune[size:] } }) - return LLMResponse{ - Response: ss, - Usage: tokenUsage, - }, err - } else { - // TODO: Is the chicken bit the only way to get here? is that acceptable? - reply, err := inferenceModel.Predict(ctx, opts) + close(rawTokenChannel) if err != nil { - return LLMResponse{}, err + rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} + } else { + rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ + Response: ss, + Usage: tokenUsage, + }} } + close(rawResultChannel) + }() + } else { + go func() { + reply, err := inferenceModel.Predict(ctx, grpcPredOpts) if tokenUsage.Prompt == 0 { tokenUsage.Prompt = int(reply.PromptTokens) } if tokenUsage.Completion == 0 { tokenUsage.Completion = int(reply.Tokens) } - return LLMResponse{ - Response: string(reply.Message), - Usage: tokenUsage, - }, err - } + if err != nil { + rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} + close(rawResultChannel) + } else { + rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }} + close(rawResultChannel) + } + }() } - return fn, nil + resultChannel = rawResultChannel + tokenChannel = rawTokenChannel + return } -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} +// TODO: Should predInput be a seperate param still, or should this fn handle extracting it from request?? +func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig, + mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) ( + // Returns: + resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) { -func Finetune(config config.BackendConfig, input, prediction string) string { + rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle]) + resultChannel = rawChannel + + if request.N == 0 { // number of completions to return + request.N = 1 + } + images := []string{} + for _, m := range request.Messages { + images = append(images, m.StringImages...) + } + + for i := 0; i < request.N; i++ { + + individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{ + Text: predInput, + Images: images, + RawMessages: request.Messages, + }, bc, enableTokenChannels) + if infErr != nil { + err = infErr // Avoids complaints about redeclaring err but looks dumb + return + } + completionChannels = append(completionChannels, individualResultChannel) + tokenChannels = append(tokenChannels, tokenChannel) + } + + go func() { + initialBundle := LLMResponseBundle{ + Request: request, + Response: []schema.Choice{}, + Usage: TokenUsage{}, + } + + wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] { + if iv.Error != nil { + ov.Error = iv.Error + // TODO: Decide if we should wipe partials or not? + return ov + } + ov.Value.Usage.Prompt += iv.Value.Usage.Prompt + ov.Value.Usage.Completion += iv.Value.Usage.Completion + + ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value)) + return ov + }, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true) + wg.Wait() + + }() + + return +} + +func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] + llmbs.ftMutex.Lock() + reg, ok := llmbs.cutstrings[c] if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] + llmbs.cutstrings[c] = regexp.MustCompile(c) + reg = llmbs.cutstrings[c] } - mu.Unlock() + llmbs.ftMutex.Unlock() prediction = reg.ReplaceAllString(prediction, "") } diff --git a/core/backend/options.go b/core/backend/options.go index 5b303b05..0b4e56db 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -10,7 +10,7 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { if so.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } @@ -19,12 +19,12 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode opts = append(opts, model.EnableParallelRequests) } - if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + if bc.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(bc.GRPC.Attempts)) } - if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + if bc.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(bc.GRPC.AttemptsSleepTime)) } for k, v := range so.ExternalGRPCBackends { @@ -34,7 +34,7 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode return opts } -func getSeed(c config.BackendConfig) int32 { +func getSeed(c *config.BackendConfig) int32 { seed := int32(*c.Seed) if seed == config.RAND_SEED { seed = rand.Int31() @@ -43,7 +43,7 @@ func getSeed(c config.BackendConfig) int32 { return seed } -func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { +func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -104,47 +104,47 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { } } -func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(bc *config.BackendConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" - if c.PromptCachePath != "" { - p := filepath.Join(modelPath, c.PromptCachePath) + if bc.PromptCachePath != "" { + p := filepath.Join(modelPath, bc.PromptCachePath) os.MkdirAll(filepath.Dir(p), 0755) promptCachePath = p } return &pb.PredictOptions{ - Temperature: float32(*c.Temperature), - TopP: float32(*c.TopP), - NDraft: c.NDraft, - TopK: int32(*c.TopK), - Tokens: int32(*c.Maxtokens), - Threads: int32(*c.Threads), - PromptCacheAll: c.PromptCacheAll, - PromptCacheRO: c.PromptCacheRO, + Temperature: float32(*bc.Temperature), + TopP: float32(*bc.TopP), + NDraft: bc.NDraft, + TopK: int32(*bc.TopK), + Tokens: int32(*bc.Maxtokens), + Threads: int32(*bc.Threads), + PromptCacheAll: bc.PromptCacheAll, + PromptCacheRO: bc.PromptCacheRO, PromptCachePath: promptCachePath, - F16KV: *c.F16, - DebugMode: *c.Debug, - Grammar: c.Grammar, - NegativePromptScale: c.NegativePromptScale, - RopeFreqBase: c.RopeFreqBase, - RopeFreqScale: c.RopeFreqScale, - NegativePrompt: c.NegativePrompt, - Mirostat: int32(*c.LLMConfig.Mirostat), - MirostatETA: float32(*c.LLMConfig.MirostatETA), - MirostatTAU: float32(*c.LLMConfig.MirostatTAU), - Debug: *c.Debug, - StopPrompts: c.StopWords, - Repeat: int32(c.RepeatPenalty), - NKeep: int32(c.Keep), - Batch: int32(c.Batch), - IgnoreEOS: c.IgnoreEOS, - Seed: getSeed(c), - FrequencyPenalty: float32(c.FrequencyPenalty), - MLock: *c.MMlock, - MMap: *c.MMap, - MainGPU: c.MainGPU, - TensorSplit: c.TensorSplit, - TailFreeSamplingZ: float32(*c.TFZ), - TypicalP: float32(*c.TypicalP), + F16KV: *bc.F16, + DebugMode: *bc.Debug, + Grammar: bc.Grammar, + NegativePromptScale: bc.NegativePromptScale, + RopeFreqBase: bc.RopeFreqBase, + RopeFreqScale: bc.RopeFreqScale, + NegativePrompt: bc.NegativePrompt, + Mirostat: int32(*bc.LLMConfig.Mirostat), + MirostatETA: float32(*bc.LLMConfig.MirostatETA), + MirostatTAU: float32(*bc.LLMConfig.MirostatTAU), + Debug: *bc.Debug, + StopPrompts: bc.StopWords, + Repeat: int32(bc.RepeatPenalty), + NKeep: int32(bc.Keep), + Batch: int32(bc.Batch), + IgnoreEOS: bc.IgnoreEOS, + Seed: getSeed(bc), + FrequencyPenalty: float32(bc.FrequencyPenalty), + MLock: *bc.MMlock, + MMap: *bc.MMap, + MainGPU: bc.MainGPU, + TensorSplit: bc.TensorSplit, + TailFreeSamplingZ: float32(*bc.TFZ), + TypicalP: float32(*bc.TypicalP), } } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 4c3859df..6761c2ac 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -7,11 +7,48 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { +type TranscriptionBackendService struct { + ml *model.ModelLoader + bcl *config.BackendConfigLoader + appConfig *config.ApplicationConfig +} + +func NewTranscriptionBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TranscriptionBackendService { + return &TranscriptionBackendService{ + ml: ml, + bcl: bcl, + appConfig: appConfig, + } +} + +func (tbs *TranscriptionBackendService) Transcribe(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.TranscriptionResult] { + responseChannel := make(chan concurrency.ErrorOr[*schema.TranscriptionResult]) + go func(request *schema.OpenAIRequest) { + bc, request, err := tbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, tbs.appConfig) + if err != nil { + responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: fmt.Errorf("failed reading parameters from request:%w", err)} + close(responseChannel) + return + } + + tr, err := modelTranscription(request.File, request.Language, tbs.ml, bc, tbs.appConfig) + if err != nil { + responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: err} + close(responseChannel) + return + } + responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Value: tr} + close(responseChannel) + }(request) + return responseChannel +} + +func modelTranscription(audio, language string, ml *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), diff --git a/core/backend/tts.go b/core/backend/tts.go index f97b6202..d1fa270d 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -7,29 +7,60 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext +type TextToSpeechBackendService struct { + ml *model.ModelLoader + bcl *config.BackendConfigLoader + appConfig *config.ApplicationConfig +} - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) +func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService { + return &TextToSpeechBackendService{ + ml: ml, + bcl: bcl, + appConfig: appConfig, } } -func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { +func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) <-chan concurrency.ErrorOr[*string] { + responseChannel := make(chan concurrency.ErrorOr[*string]) + go func(request *schema.TTSRequest) { + cfg, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath, + config.LoadOptionDebug(ttsbs.appConfig.Debug), + config.LoadOptionThreads(ttsbs.appConfig.Threads), + config.LoadOptionContextSize(ttsbs.appConfig.ContextSize), + config.LoadOptionF16(ttsbs.appConfig.F16), + ) + if err != nil { + responseChannel <- concurrency.ErrorOr[*string]{Error: err} + close(responseChannel) + return + } + + if request.Backend != "" { + cfg.Backend = request.Backend + } + + outFile, _, err := modelTTS(cfg.Backend, request.Input, cfg.Model, request.Voice, ttsbs.ml, ttsbs.appConfig, cfg) + if err != nil { + responseChannel <- concurrency.ErrorOr[*string]{Error: err} + close(responseChannel) + return + } + responseChannel <- concurrency.ErrorOr[*string]{Value: &outFile} + close(responseChannel) + }(request) + return responseChannel +} + +func modelTTS(backend, text, modelFile string, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig *config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend @@ -37,7 +68,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, grpcOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ + opts := modelOpts(&config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(appConfig.Context), @@ -87,3 +118,19 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, return filePath, res, err } + +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +} diff --git a/core/cli/run.go b/core/cli/run.go index 0f3ba2de..cafc0b54 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -124,11 +124,11 @@ func (r *RunCMD) Run(ctx *Context) error { } if r.PreloadBackendOnly { - _, _, _, err := startup.Startup(opts...) + _, err := startup.Startup(opts...) return err } - cl, ml, options, err := startup.Startup(opts...) + application, err := startup.Startup(opts...) if err != nil { return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) @@ -137,7 +137,7 @@ func (r *RunCMD) Run(ctx *Context) error { // Watch the configuration directory // If the directory does not exist, we don't watch it if _, err := os.Stat(r.LocalaiConfigDir); err == nil { - closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, options) + closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, application.ApplicationConfig) defer closeConfigWatcherFn() if err != nil { @@ -145,7 +145,7 @@ func (r *RunCMD) Run(ctx *Context) error { } } - appHTTP, err := http.App(cl, ml, options) + appHTTP, err := http.App(application) if err != nil { log.Error().Err(err).Msg("error during HTTP App construction") return err diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 9f36a77c..f14a1a87 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -43,11 +44,21 @@ func (t *TranscriptCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts) - if err != nil { - return err + tbs := backend.NewTranscriptionBackendService(ml, cl, opts) + + resultChannel := tbs.Transcribe(&schema.OpenAIRequest{ + PredictionOptions: schema.PredictionOptions{ + Language: t.Language, + }, + File: t.Filename, + }) + + r := <-resultChannel + + if r.Error != nil { + return r.Error } - for _, segment := range tr.Segments { + for _, segment := range r.Value.Segments { fmt.Println(segment.Start.String(), "-", segment.Text) } return nil diff --git a/core/cli/tts.go b/core/cli/tts.go index 1d8fd3a3..c7758c48 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -9,6 +9,7 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -42,20 +43,29 @@ func (t *TTSCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - options := config.BackendConfig{} - options.SetDefaults() + ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts) - filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options) - if err != nil { - return err + request := &schema.TTSRequest{ + Model: t.Model, + Input: text, + Backend: t.Backend, + Voice: t.Voice, + } + + resultsChannel := ttsbs.TextToAudioFile(request) + + rawResult := <-resultsChannel + + if rawResult.Error != nil { + return rawResult.Error } if outputFile != "" { - if err := os.Rename(filePath, outputFile); err != nil { + if err := os.Rename(*rawResult.Value, outputFile); err != nil { return err } - fmt.Printf("Generate file %s\n", outputFile) + fmt.Printf("Generated file %q\n", outputFile) } else { - fmt.Printf("Generate file %s\n", filePath) + fmt.Printf("Generated file %q\n", *rawResult.Value) } return nil } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 81c92d01..47e4829d 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -1,22 +1,7 @@ package config import ( - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "sort" - "strings" - "sync" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/downloader" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v3" - - "github.com/charmbracelet/glamour" ) const ( @@ -199,7 +184,7 @@ func (c *BackendConfig) FunctionToCall() string { } func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { - lo := &LoadOptions{} + lo := &ConfigLoaderOptions{} lo.Apply(opts...) ctx := lo.ctxSize @@ -312,287 +297,3 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.Debug = &trueV } } - -////// Config Loader //////// - -type BackendConfigLoader struct { - configs map[string]BackendConfig - sync.Mutex -} - -type LoadOptions struct { - debug bool - threads, ctxSize int - f16 bool -} - -func LoadOptionDebug(debug bool) ConfigLoaderOption { - return func(o *LoadOptions) { - o.debug = debug - } -} - -func LoadOptionThreads(threads int) ConfigLoaderOption { - return func(o *LoadOptions) { - o.threads = threads - } -} - -func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { - return func(o *LoadOptions) { - o.ctxSize = ctxSize - } -} - -func LoadOptionF16(f16 bool) ConfigLoaderOption { - return func(o *LoadOptions) { - o.f16 = f16 - } -} - -type ConfigLoaderOption func(*LoadOptions) - -func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { - for _, l := range options { - l(lo) - } -} - -// Load a config file for a model -func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - - // Load a config file if present after the model name - cfg := &BackendConfig{ - PredictionOptions: schema.PredictionOptions{ - Model: modelName, - }, - } - - cfgExisting, exists := cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } else { - // Try loading a model config file - modelConfig := filepath.Join(modelPath, modelName+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := cl.LoadBackendConfig( - modelConfig, opts..., - ); err != nil { - return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } - } - } - - cfg.SetDefaults(opts...) - - return cfg, nil -} - -func NewBackendConfigLoader() *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), - } -} -func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { - c := &[]*BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - for _, cc := range *c { - cc.SetDefaults(opts...) - } - - return *c, nil -} - -func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - lo := &LoadOptions{} - lo.Apply(opts...) - - c := &BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - c.SetDefaults(opts...) - return c, nil -} - -func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadBackendConfigFile(file, opts...) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cm.configs[cc.Name] = *cc - } - return nil -} - -func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { - cl.Lock() - defer cl.Unlock() - c, err := ReadBackendConfig(file, opts...) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cl.configs[c.Name] = *c - return nil -} - -func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { - cl.Lock() - defer cl.Unlock() - v, exists := cl.configs[m] - return v, exists -} - -func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { - cl.Lock() - defer cl.Unlock() - var res []BackendConfig - for _, v := range cl.configs { - res = append(res, v) - } - - sort.SliceStable(res, func(i, j int) bool { - return res[i].Name < res[j].Name - }) - - return res -} - -func (cl *BackendConfigLoader) ListBackendConfigs() []string { - cl.Lock() - defer cl.Unlock() - var res []string - for k := range cl.configs { - res = append(res, k) - } - return res -} - -// Preload prepare models if they are not local but url or huggingface repositories -func (cl *BackendConfigLoader) Preload(modelPath string) error { - cl.Lock() - defer cl.Unlock() - - status := func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - } - - log.Info().Msgf("Preloading models from %s", modelPath) - - renderMode := "dark" - if os.Getenv("COLOR") != "" { - renderMode = os.Getenv("COLOR") - } - - glamText := func(t string) { - out, err := glamour.Render(t, renderMode) - if err == nil && os.Getenv("NO_COLOR") == "" { - fmt.Println(out) - } else { - fmt.Println(t) - } - } - - for i, config := range cl.configs { - - // Download files and verify their SHA - for _, file := range config.DownloadFiles { - log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - - if err := utils.VerifyPath(file.Filename, modelPath); err != nil { - return err - } - // Create file path - filePath := filepath.Join(modelPath, file.Filename) - - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { - return err - } - } - - modelURL := config.PredictionOptions.Model - modelURL = downloader.ConvertURL(modelURL) - - if downloader.LooksLikeURL(modelURL) { - // md5 of model name - md5Name := utils.MD5(modelURL) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) - if err != nil { - return err - } - } - - cc := cl.configs[i] - c := &cc - c.PredictionOptions.Model = md5Name - cl.configs[i] = *c - } - if cl.configs[i].Name != "" { - glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) - } - if cl.configs[i].Description != "" { - //glamText("**Description**") - glamText(cl.configs[i].Description) - } - if cl.configs[i].Usage != "" { - //glamText("**Usage**") - glamText(cl.configs[i].Usage) - } - } - return nil -} - -// LoadBackendConfigsFromPath reads all the configurations of the models from a path -// (non-recursive) -func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { - continue - } - c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go new file mode 100644 index 00000000..62dfc1e0 --- /dev/null +++ b/core/config/backend_config_loader.go @@ -0,0 +1,509 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/charmbracelet/glamour" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/grammar" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v2" +) + +type BackendConfigLoader struct { + configs map[string]BackendConfig + sync.Mutex +} + +type ConfigLoaderOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *ConfigLoaderOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *ConfigLoaderOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *ConfigLoaderOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *ConfigLoaderOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*ConfigLoaderOptions) + +func (lo *ConfigLoaderOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), + } +} + +func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() + c, err := readBackendConfig(file, opts...) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + bcl.configs[c.Name] = *c + return nil +} + +func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + bcl.Lock() + defer bcl.Unlock() + v, exists := bcl.configs[m] + return v, exists +} + +func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + bcl.Lock() + defer bcl.Unlock() + var res []BackendConfig + for _, v := range bcl.configs { + res = append(res, v) + } + sort.SliceStable(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + return res +} + +func (bcl *BackendConfigLoader) ListBackendConfigs() []string { + bcl.Lock() + defer bcl.Unlock() + var res []string + for k := range bcl.configs { + res = append(res, k) + } + return res +} + +// Preload prepare models if they are not local but url or huggingface repositories +func (bcl *BackendConfigLoader) Preload(modelPath string) error { + bcl.Lock() + defer bcl.Unlock() + + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + } + + log.Info().Msgf("Preloading models from %s", modelPath) + + renderMode := "dark" + if os.Getenv("COLOR") != "" { + renderMode = os.Getenv("COLOR") + } + + glamText := func(t string) { + out, err := glamour.Render(t, renderMode) + if err == nil && os.Getenv("NO_COLOR") == "" { + fmt.Println(out) + } else { + fmt.Println(t) + } + } + + for i, config := range bcl.configs { + + // Download files and verify their SHA + for _, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) + + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) + + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + return err + } + } + + modelURL := config.PredictionOptions.Model + modelURL = downloader.ConvertURL(modelURL) + + if downloader.LooksLikeURL(modelURL) { + // md5 of model name + md5Name := utils.MD5(modelURL) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) + if err != nil { + return err + } + } + + cc := bcl.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + bcl.configs[i] = *c + } + if bcl.configs[i].Name != "" { + glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name)) + } + if bcl.configs[i].Description != "" { + //glamText("**Description**") + glamText(bcl.configs[i].Description) + } + if bcl.configs[i].Usage != "" { + //glamText("**Usage**") + glamText(bcl.configs[i].Usage) + } + } + return nil +} + +func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + continue + } + c, err := readBackendConfig(filepath.Join(path, file.Name()), opts...) + if err == nil { + bcl.configs[c.Name] = *c + } + } + + return nil +} + +func (bcl *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() + c, err := readBackendConfigFile(file, opts...) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + bcl.configs[cc.Name] = *cc + } + return nil +} + +////////// + +// Load a config file for a model +func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName string, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + // Load a config file if present after the model name + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, + } + + cfgExisting, exists := bcl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } else { + // Load a config file if present after the model name + modelConfig := filepath.Join(modelPath, modelName+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := bcl.LoadBackendConfig(modelConfig); err != nil { + return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = bcl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } + } + } + + cfg.SetDefaults(opts...) + return cfg, nil +} + +func readBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + c := &[]*BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + for _, cc := range *c { + cc.SetDefaults(opts...) + } + + return *c, nil +} + +func readBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + c := &BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + c.SetDefaults(opts...) + return c, nil +} + +func (bcl *BackendConfigLoader) LoadBackendConfigForModelAndOpenAIRequest(modelFile string, input *schema.OpenAIRequest, appConfig *ApplicationConfig) (*BackendConfig, *schema.OpenAIRequest, error) { + cfg, err := bcl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + LoadOptionContextSize(appConfig.ContextSize), + LoadOptionDebug(appConfig.Debug), + LoadOptionF16(appConfig.F16), + LoadOptionThreads(appConfig.Threads), + ) + + // Set the parameters for the language model prediction + updateBackendConfigFromOpenAIRequest(cfg, input) + + return cfg, input, err +} + +func updateBackendConfigFromOpenAIRequest(bc *BackendConfig, request *schema.OpenAIRequest) { + if request.Echo { + bc.Echo = request.Echo + } + if request.TopK != nil && *request.TopK != 0 { + bc.TopK = request.TopK + } + if request.TopP != nil && *request.TopP != 0 { + bc.TopP = request.TopP + } + + if request.Backend != "" { + bc.Backend = request.Backend + } + + if request.ClipSkip != 0 { + bc.Diffusers.ClipSkip = request.ClipSkip + } + + if request.ModelBaseName != "" { + bc.AutoGPTQ.ModelBaseName = request.ModelBaseName + } + + if request.NegativePromptScale != 0 { + bc.NegativePromptScale = request.NegativePromptScale + } + + if request.UseFastTokenizer { + bc.UseFastTokenizer = request.UseFastTokenizer + } + + if request.NegativePrompt != "" { + bc.NegativePrompt = request.NegativePrompt + } + + if request.RopeFreqBase != 0 { + bc.RopeFreqBase = request.RopeFreqBase + } + + if request.RopeFreqScale != 0 { + bc.RopeFreqScale = request.RopeFreqScale + } + + if request.Grammar != "" { + bc.Grammar = request.Grammar + } + + if request.Temperature != nil && *request.Temperature != 0 { + bc.Temperature = request.Temperature + } + + if request.Maxtokens != nil && *request.Maxtokens != 0 { + bc.Maxtokens = request.Maxtokens + } + + switch stop := request.Stop.(type) { + case string: + if stop != "" { + bc.StopWords = append(bc.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + bc.StopWords = append(bc.StopWords, s) + } + } + } + + if len(request.Tools) > 0 { + for _, tool := range request.Tools { + request.Functions = append(request.Functions, tool.Function) + } + } + + if request.ToolsChoice != nil { + var toolChoice grammar.Tool + switch content := request.ToolsChoice.(type) { + case string: + _ = json.Unmarshal([]byte(content), &toolChoice) + case map[string]interface{}: + dat, _ := json.Marshal(content) + _ = json.Unmarshal(dat, &toolChoice) + } + request.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + + // Decode each request's message content + index := 0 + for i, m := range request.Messages { + switch content := m.Content.(type) { + case string: + request.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + for _, pp := range c { + if pp.Type == "text" { + request.Messages[i].StringContent = pp.Text + } else if pp.Type == "image_url" { + // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: + base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL) + if err == nil { + request.Messages[i].StringImages = append(request.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + // set a placeholder for each image + request.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + request.Messages[i].StringContent + index++ + } else { + fmt.Print("Failed encoding image", err) + } + } + } + } + } + + if request.RepeatPenalty != 0 { + bc.RepeatPenalty = request.RepeatPenalty + } + + if request.FrequencyPenalty != 0 { + bc.FrequencyPenalty = request.FrequencyPenalty + } + + if request.PresencePenalty != 0 { + bc.PresencePenalty = request.PresencePenalty + } + + if request.Keep != 0 { + bc.Keep = request.Keep + } + + if request.Batch != 0 { + bc.Batch = request.Batch + } + + if request.IgnoreEOS { + bc.IgnoreEOS = request.IgnoreEOS + } + + if request.Seed != nil { + bc.Seed = request.Seed + } + + if request.TypicalP != nil { + bc.TypicalP = request.TypicalP + } + + switch inputs := request.Input.(type) { + case string: + if inputs != "" { + bc.InputStrings = append(bc.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + bc.InputStrings = append(bc.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + bc.InputToken = append(bc.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := request.FunctionCall.(type) { + case string: + if fnc != "" { + bc.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + bc.SetFunctionCallNameString(name) + } + + switch p := request.Prompt.(type) { + case string: + bc.PromptStrings = append(bc.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + bc.PromptStrings = append(bc.PromptStrings, s) + } + } + } +} diff --git a/core/config/exports_test.go b/core/config/exports_test.go new file mode 100644 index 00000000..70ba84e6 --- /dev/null +++ b/core/config/exports_test.go @@ -0,0 +1,6 @@ +package config + +// This file re-exports private functions to be used directly in unit tests. +// Since this file's name ends in _test.go, theoretically these should not be exposed past the tests. + +var ReadBackendConfigFile = readBackendConfigFile diff --git a/core/http/api.go b/core/http/api.go index af38512a..7094899a 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -1,23 +1,20 @@ package http import ( - "encoding/json" "errors" - "os" "strings" - "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/go-skynet/LocalAI/core" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" "github.com/gofiber/swagger" // swagger handler "github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs" "github.com/go-skynet/LocalAI/core/http/endpoints/localai" "github.com/go-skynet/LocalAI/core/http/endpoints/openai" - - "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" @@ -55,13 +52,12 @@ func readAuthHeader(c *fiber.Ctx) string { // @securityDefinitions.apikey BearerAuth // @in header // @name Authorization - -func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { +func App(application *core.Application) (*fiber.App, error) { // Return errors as JSON responses app := fiber.New(fiber.Config{ Views: renderEngine(), - BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: appConfig.DisableMessage, + BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: application.ApplicationConfig.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -82,7 +78,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi }, }) - if appConfig.Debug { + if application.ApplicationConfig.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) @@ -90,7 +86,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Default middleware config - if !appConfig.Debug { + if !application.ApplicationConfig.Debug { app.Use(recover.New()) } @@ -108,25 +104,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(appConfig.ApiKeys) == 0 { - return c.Next() - } - - // Check for api_keys.json file - fileContent, err := os.ReadFile("api_keys.json") - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) - } - - // Add file keys to options.ApiKeys - appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) - } - - if len(appConfig.ApiKeys) == 0 { + if len(application.ApplicationConfig.ApiKeys) == 0 { return c.Next() } @@ -142,7 +120,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi } apiKey := authHeaderParts[1] - for _, key := range appConfig.ApiKeys { + for _, key := range application.ApplicationConfig.ApiKeys { if apiKey == key { return c.Next() } @@ -151,20 +129,22 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) } - if appConfig.CORS { + if application.ApplicationConfig.CORS { var c func(ctx *fiber.Ctx) error - if appConfig.CORSAllowOrigins == "" { + if application.ApplicationConfig.CORSAllowOrigins == "" { c = cors.New() } else { - c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins}) + c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins}) } app.Use(c) } + fiberContextExtractor := fiberContext.NewFiberContextExtractor(application.ModelLoader, application.ApplicationConfig) + // LocalAI API endpoints - galleryService := services.NewGalleryService(appConfig.ModelPath) - galleryService.Start(appConfig.Context, cl) + galleryService := services.NewGalleryService(application.ApplicationConfig.ModelPath) + galleryService.Start(application.ApplicationConfig.Context, application.BackendConfigLoader) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { @@ -172,29 +152,17 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi }{Version: internal.PrintableVersion()}) }) - // Make sure directories exists - os.MkdirAll(appConfig.ImageDir, 0755) - os.MkdirAll(appConfig.AudioDir, 0755) - os.MkdirAll(appConfig.UploadDir, 0755) - os.MkdirAll(appConfig.ConfigsDir, 0755) - os.MkdirAll(appConfig.ModelPath, 0755) - - // Load config jsons - utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) - app.Get("/swagger/*", swagger.HandlerDefault) // default welcomeRoute( app, - cl, - ml, - appConfig, + application.BackendConfigLoader, + application.ModelLoader, + application.ApplicationConfig, auth, ) - modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) + modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) @@ -203,83 +171,85 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint()) app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint()) - app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig)) - - // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) - // Stores - sl := model.NewModelLoader("") - app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig)) - app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig)) - app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig)) - app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig)) + storeLoader := model.NewModelLoader("") // TODO: Investigate if this should be migrated to application and reused. Should the path be configurable? Merging for now. + app.Post("/stores/set", auth, localai.StoresSetEndpoint(storeLoader, application.ApplicationConfig)) + app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(storeLoader, application.ApplicationConfig)) + app.Post("/stores/get", auth, localai.StoresGetEndpoint(storeLoader, application.ApplicationConfig)) + app.Post("/stores/find", auth, localai.StoresFindEndpoint(storeLoader, application.ApplicationConfig)) - // openAI compatible API endpoint + // openAI compatible API endpoints // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) - app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/v1/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) + app.Post("/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) // assistant - app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + // TODO: Refactor this to the new style eventually + app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) // files - app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) - app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) - app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) - app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) - app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) - app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) - app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) + app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) + app.Post("/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig)) - app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(fiberContextExtractor, application.TranscriptionBackendService)) + app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig)) + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(fiberContextExtractor, application.ImageGenerationBackendService)) - if appConfig.ImageDir != "" { - app.Static("/generated-images", appConfig.ImageDir) + // Elevenlabs + app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) + + // LocalAI TTS? + app.Post("/tts", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) + + if application.ApplicationConfig.ImageDir != "" { + app.Static("/generated-images", application.ApplicationConfig.ImageDir) } - if appConfig.AudioDir != "" { - app.Static("/generated-audio", appConfig.AudioDir) + if application.ApplicationConfig.AudioDir != "" { + app.Static("/generated-audio", application.ApplicationConfig.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -291,13 +261,12 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi app.Get("/readyz", ok) // Experimental Backend Statistics Module - backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now - app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor)) - app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor)) + app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService)) + app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService)) // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) + app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) diff --git a/core/http/api_test.go b/core/http/api_test.go index 1553ed21..bf8feb1c 100644 --- a/core/http/api_test.go +++ b/core/http/api_test.go @@ -12,7 +12,9 @@ import ( "os" "path/filepath" "runtime" + "strings" + "github.com/go-skynet/LocalAI/core" "github.com/go-skynet/LocalAI/core/config" . "github.com/go-skynet/LocalAI/core/http" "github.com/go-skynet/LocalAI/core/schema" @@ -205,9 +207,7 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string var modelDir string - var bcl *config.BackendConfigLoader - var ml *model.ModelLoader - var applicationConfig *config.ApplicationConfig + var application *core.Application commonOpts := []config.AppOption{ config.WithDebug(true), @@ -252,7 +252,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithGalleries(galleries), @@ -261,7 +261,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(backendAssetsDir))...) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = App(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -474,11 +474,11 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) + Expect(resp2.Choices[0].Message.ToolCalls[0].Function).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -487,9 +487,9 @@ var _ = Describe("API test", func() { }) It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() { - if runtime.GOOS != "linux" { - Skip("test supported only on linux") - } + // if runtime.GOOS != "linux" { + // Skip("test supported only on linux") + // } modelName := "codellama" response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml", @@ -504,7 +504,7 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) - }, "360s", "10s").Should(Equal(true)) + }, "480s", "10s").Should(Equal(true)) By("testing chat") resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ @@ -551,11 +551,13 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) + fmt.Printf("\n--- %+v\n\n", resp2.Choices[0].Message) + Expect(resp2.Choices[0].Message.ToolCalls).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.ToolCalls[0]).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -609,7 +611,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithAudioDir(tmpdir), @@ -620,7 +622,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = App(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -724,14 +726,14 @@ var _ = Describe("API test", func() { var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err = startup.Startup( append(commonOpts, config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), config.WithModelPath(modelPath), )...) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = App(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -761,6 +763,11 @@ var _ = Describe("API test", func() { Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8? }) It("can generate completions via ggml", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -768,6 +775,11 @@ var _ = Describe("API test", func() { }) It("can generate chat completions via ggml", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -775,6 +787,11 @@ var _ = Describe("API test", func() { }) It("can generate completions from model configs", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -782,6 +799,11 @@ var _ = Describe("API test", func() { }) It("can generate chat completions from model configs", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -868,9 +890,9 @@ var _ = Describe("API test", func() { Context("backends", func() { It("runs rwkv completion", func() { - if runtime.GOOS != "linux" { - Skip("test supported only on linux") - } + // if runtime.GOOS != "linux" { + // Skip("test supported only on linux") + // } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) @@ -891,17 +913,20 @@ var _ = Describe("API test", func() { } Expect(err).ToNot(HaveOccurred()) - text += response.Choices[0].Text - tokens++ + + if len(response.Choices) > 0 { + text += response.Choices[0].Text + tokens++ + } } Expect(text).ToNot(BeEmpty()) Expect(text).To(ContainSubstring("five")) Expect(tokens).ToNot(Or(Equal(1), Equal(0))) }) It("runs rwkv chat completion", func() { - if runtime.GOOS != "linux" { - Skip("test supported only on linux") - } + // if runtime.GOOS != "linux" { + // Skip("test supported only on linux") + // } resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) Expect(err).ToNot(HaveOccurred()) @@ -1010,14 +1035,14 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithModelPath(modelPath), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = App(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -1041,18 +1066,33 @@ var _ = Describe("API test", func() { } }) It("can generate chat completions from config file (list1)", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate chat completions from config file (list2)", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate edit completions from config file", func() { + bt, ok := os.LookupEnv("BUILD_TYPE") + if ok && strings.ToLower(bt) == "metal" { + Skip("GGML + Metal is known flaky, skip test temporarily") + } + request := openaigo.EditCreateRequestBody{ Model: "list2", Instruction: "foo", diff --git a/core/http/ctx/fiber.go b/core/http/ctx/fiber.go index ffb63111..99fbcde9 100644 --- a/core/http/ctx/fiber.go +++ b/core/http/ctx/fiber.go @@ -1,43 +1,88 @@ package fiberContext import ( + "context" + "encoding/json" "fmt" "strings" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) +type FiberContextExtractor struct { + ml *model.ModelLoader + appConfig *config.ApplicationConfig +} + +func NewFiberContextExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContextExtractor { + return &FiberContextExtractor{ + ml: ml, + appConfig: appConfig, + } +} + // ModelFromContext returns the model from the context // If no model is specified, it will take the first available // Takes a model string as input which should be the one received from the user request. // It returns the model name resolved from the context and an error if any. -func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) { - if ctx.Params("model") != "" { - modelInput = ctx.Params("model") +func (fce *FiberContextExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) { + ctxPM := ctx.Params("model") + if ctxPM != "" { + log.Debug().Msgf("[FCE] Overriding param modelInput %q with ctx.Params value %q", modelInput, ctxPM) + modelInput = ctxPM } // Set model from bearer token, if available - bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + bearer := strings.TrimPrefix(ctx.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer) // If no model was specified, take the first available if modelInput == "" && !bearerExists && firstModel { - models, _ := loader.ListModels() + models, _ := fce.ml.ListModels() if len(models) > 0 { modelInput = models[0] - log.Debug().Msgf("No model specified, using: %s", modelInput) + log.Debug().Msgf("[FCE] No model specified, using first available: %s", modelInput) } else { - log.Debug().Msgf("No model specified, returning error") - return "", fmt.Errorf("no model specified") + log.Warn().Msgf("[FCE] No model specified, none available") + return "", fmt.Errorf("[fce] no model specified, none available") } } // If a model is found in bearer token takes precedence if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) + log.Debug().Msgf("[FCE] Using model from bearer token: %s", bearer) modelInput = bearer } + + if modelInput == "" { + log.Warn().Msg("[FCE] modelInput is empty") + } return modelInput, nil } + +// TODO: Do we still need the first return value? +func (fce *FiberContextExtractor) OpenAIRequestFromContext(c *fiber.Ctx, firstModel bool) (string, *schema.OpenAIRequest, error) { + input := new(schema.OpenAIRequest) + + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, fmt.Errorf("failed parsing request body: %w", err) + } + + received, _ := json.Marshal(input) + + ctx, cancel := context.WithCancel(fce.appConfig.Context) + input.Context = ctx + input.Cancel = cancel + + log.Debug().Msgf("Request received: %s", string(received)) + + var err error + input.Model, err = fce.ModelFromContext(c, input.Model, firstModel) + + return input.Model, input, err +} diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 841f9b5f..4f5db463 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -2,9 +2,7 @@ package elevenlabs import ( "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -17,7 +15,7 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] -func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.ElevenLabsTTSRequest) @@ -28,34 +26,21 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi return err } - modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false) + var err error + input.ModelID, err = fce.ModelFromContext(c, input.ModelID, false) if err != nil { - modelFile = input.ModelID log.Warn().Msgf("Model not found in context: %s", input.ModelID) } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - if err != nil { - modelFile = input.ModelID - log.Warn().Msgf("Model not found in context: %s", input.ModelID) - } else { - if input.ModelID != "" { - modelFile = input.ModelID - } else { - modelFile = cfg.Model - } + responseChannel := ttsbs.TextToAudioFile(&schema.TTSRequest{ + Model: input.ModelID, + Voice: voiceID, + Input: input.Text, + }) + rawValue := <-responseChannel + if rawValue.Error != nil { + return rawValue.Error } - log.Debug().Msgf("Request for model: %s", modelFile) - - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg) - if err != nil { - return err - } - return c.Download(filePath) + return c.Download(*rawValue.Value) } } diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index 8c7a664a..dac20388 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -6,7 +6,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { +func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) @@ -23,7 +23,7 @@ func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error } } -func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { +func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 7822e024..df7841fb 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -2,9 +2,7 @@ package localai import ( "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -16,45 +14,26 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/audio/speech [post] -func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - + var err error input := new(schema.TTSRequest) // Get input data from the request body - if err := c.BodyParser(input); err != nil { + if err = c.BodyParser(input); err != nil { return err } - modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false) + input.Model, err = fce.ModelFromContext(c, input.Model, false) if err != nil { - modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - - if err != nil { - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) - } else { - modelFile = cfg.Model + responseChannel := ttsbs.TextToAudioFile(input) + rawValue := <-responseChannel + if rawValue.Error != nil { + return rawValue.Error } - log.Debug().Msgf("Request for model: %s", modelFile) - - if input.Backend != "" { - cfg.Backend = input.Backend - } - - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg) - if err != nil { - return err - } - return c.Download(filePath) + return c.Download(*rawValue.Value) } } diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go index dceb3789..72cb8b4a 100644 --- a/core/http/endpoints/openai/assistant.go +++ b/core/http/endpoints/openai/assistant.go @@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model } } - return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistantID %q", assistantID)) } } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 36d1142b..a240b024 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -5,17 +5,11 @@ import ( "bytes" "encoding/json" "fmt" - "strings" - "time" - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/go-skynet/LocalAI/core/services" "github.com/gofiber/fiber/v2" - "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -25,412 +19,82 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { - emptyMessage := "" - id := uuid.New().String() - created := int(time.Now().Unix()) - - process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - - responses <- resp - return true - }) - close(responses) - } - processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - result := "" - _, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { - result += s - // TODO: Change generated BNF grammar to be compliant with the schema so we can - // stream the result token by token here. - return true - }) - - results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls) - noActionToRun := len(results) > 0 && results[0].name == noAction - - switch { - case noActionToRun: - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt) - if err != nil { - log.Error().Err(err).Msg("error handling question") - return - } - - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, - } - - responses <- resp - - default: - for i, ss := range results { - name, args := ss.name, ss.arguments - - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: i, - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - responses <- schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: i, - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Arguments: args, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } - } - } - - close(responses) - } - +func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - processFunctions := false - funcs := grammar.Functions{} - modelFile, input, err := readRequest(c, ml, startupOptions, true) + _, request, err := fce.OpenAIRequestFromContext(c, false) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + return fmt.Errorf("failed reading parameters from request: %w", err) } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16) + traceID, finalResultChannel, _, tokenChannel, err := oais.Chat(request, false, request.Stream) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if config.FunctionsConfig.NoActionFunctionName != "" { - noActionName = config.FunctionsConfig.NoActionFunctionName - } - if config.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = config.FunctionsConfig.NoActionDescriptionName + return err } - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } + if request.Stream { - config.Grammar = input.Grammar + log.Debug().Msgf("Chat Stream request received") - // process functions if we have any defined or if we have a function call string - if len(input.Functions) > 0 && config.ShouldUseFunctions() { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.FunctionToCall() != "" { - funcs = funcs.Select(config.FunctionToCall()) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) - } - - // functions are not supported in stream mode (yet?) - toStream := input.Stream - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - // If we are using the tokenizer template, we don't need to process the messages - // unless we are processing functions - if !config.TemplateConfig.UseTokenizerTemplate || processFunctions { - - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range input.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - - fcall := i.FunctionCall - if len(i.ToolCalls) > 0 { - fcall = i.ToolCalls - } - - // First attempt to populate content via a chat message specific template - if config.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: config.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - FunctionCall: fcall, - FunctionName: i.Name, - LastMessage: messageIndex == (len(input.Messages) - 1), - Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)), - MessageIndex: messageIndex, - } - templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - - marshalAnyRole := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - marshalAny := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. - if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - - if i.FunctionCall != nil { - marshalAnyRole(i.FunctionCall) - } - if i.ToolCalls != nil { - marshalAnyRole(i.ToolCalls) - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - if i.FunctionCall != nil { - marshalAny(i.FunctionCall) - } - if i.ToolCalls != nil { - marshalAny(i.ToolCalls) - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = config.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - } - - switch { - case toStream: - - log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") + // c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") - responses := make(chan schema.OpenAIResponse) - - if !processFunctions { - go process(predInput, input, config, ml, responses) - } else { - go processTools(noActionName, predInput, input, config, ml, responses) - } - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { usage := &schema.OpenAIUsage{} toolsCalled := false - for ev := range responses { - usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - if len(ev.Choices[0].Delta.ToolCalls) > 0 { + for ev := range tokenChannel { + if ev.Error != nil { + log.Debug().Err(ev.Error).Msg("chat streaming responseChannel error") + request.Cancel() + break + } + usage = &ev.Value.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + + if len(ev.Value.Choices[0].Delta.ToolCalls) > 0 { toolsCalled = true } var buf bytes.Buffer enc := json.NewEncoder(&buf) - enc.Encode(ev) - log.Debug().Msgf("Sending chunk: %s", buf.String()) + if ev.Error != nil { + log.Debug().Err(ev.Error).Msg("[ChatEndpoint] error to debug during tokenChannel handler") + enc.Encode(ev.Error) + } else { + enc.Encode(ev.Value) + } + log.Debug().Msgf("chat streaming sending chunk: %s", buf.String()) _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) if err != nil { - log.Debug().Msgf("Sending chunk failed: %v", err) - input.Cancel() + log.Debug().Err(err).Msgf("Sending chunk failed") + request.Cancel() + break + } + err = w.Flush() + if err != nil { + log.Debug().Msg("error while flushing, closing connection") + request.Cancel() break } - w.Flush() } finishReason := "stop" if toolsCalled { finishReason = "tool_calls" - } else if toolsCalled && len(input.Tools) == 0 { + } else if toolsCalled && len(request.Tools) == 0 { finishReason = "function_call" } resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { FinishReason: finishReason, Index: 0, - Delta: &schema.Message{Content: &emptyMessage}, + Delta: &schema.Message{Content: ""}, }}, Object: "chat.completion.chunk", Usage: *usage, @@ -441,202 +105,21 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup w.WriteString("data: [DONE]\n\n") w.Flush() })) + return nil - - // no streaming mode - default: - result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) { - if !processFunctions { - // no function is called, just reply and use stop as finish reason - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) - return - } - - results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls) - noActionsToRun := len(results) > 0 && results[0].name == noActionName - - switch { - case noActionsToRun: - result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput) - if err != nil { - log.Error().Err(err).Msg("error handling question") - return - } - *c = append(*c, schema.Choice{ - Message: &schema.Message{Role: "assistant", Content: &result}}) - default: - toolChoice := schema.Choice{ - Message: &schema.Message{ - Role: "assistant", - }, - } - - if len(input.Tools) > 0 { - toolChoice.FinishReason = "tool_calls" - } - - for _, ss := range results { - name, args := ss.name, ss.arguments - if len(input.Tools) > 0 { - // If we are using tools, we condense the function calls into - // a single response choice with all the tools - toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, - schema.ToolCall{ - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, - }, - ) - } else { - // otherwise we return more choices directly - *c = append(*c, schema.Choice{ - FinishReason: "function_call", - Message: &schema.Message{ - Role: "assistant", - FunctionCall: map[string]interface{}{ - "name": name, - "arguments": args, - }, - }, - }) - } - } - - if len(input.Tools) > 0 { - // we need to append our result if we are using tools - *c = append(*c, toolChoice) - } - } - - }, nil) - if err != nil { - return err - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, - } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) - - // Return the prediction in the response body - return c.JSON(resp) } + // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? + rawResponse := <-finalResultChannel + + if rawResponse.Error != nil { + return rawResponse.Error + } + + jsonResult, _ := json.Marshal(rawResponse.Value) + log.Debug().Str("jsonResult", string(jsonResult)).Msg("Chat Final Response") + + // Return the prediction in the response body + return c.JSON(rawResponse.Value) } } - -func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(args), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = backend.Finetune(*config, prompt, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - return message, nil - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU/GPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) - } - - predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil) - if err != nil { - log.Error().Err(err).Msg("model inference failed") - return "", err - } - - prediction, err := predFunc() - if err != nil { - log.Error().Err(err).Msg("prediction failed") - return "", err - } - return backend.Finetune(*config, prompt, prediction.Response), nil -} - -type funcCallResults struct { - name string - arguments string -} - -func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { - results := []funcCallResults{} - - // TODO: use generics to avoid this code duplication - if multipleResults { - ss := []map[string]interface{}{} - s := utils.EscapeNewLines(llmresult) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - for _, s := range ss { - func_name, ok := s["function"] - if !ok { - continue - } - args, ok := s["arguments"] - if !ok { - continue - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - continue - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - } else { - // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s := utils.EscapeNewLines(llmresult) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss["function"] - if !ok { - return results - } - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - if !ok { - return results - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - return results - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - - return results -} diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index 69923475..d8b412a9 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -4,18 +4,13 @@ import ( "bufio" "bytes" "encoding/json" - "errors" "fmt" - "time" - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" - "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -25,116 +20,50 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - id := uuid.New().String() - created := int(time.Now().Unix()) - - process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - +func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - modelFile, input, err := readRequest(c, ml, appConfig, true) + _, request, err := fce.OpenAIRequestFromContext(c, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - log.Debug().Msgf("`input`: %+v", input) + log.Debug().Msgf("`OpenAIRequest`: %+v", request) - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) + traceID, finalResultChannel, _, _, tokenChannel, err := oais.Completion(request, false, request.Stream) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + return err } - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } + if request.Stream { + log.Debug().Msgf("Completion Stream request received") - config.Grammar = input.Grammar - - log.Debug().Msgf("Parameter Config: %+v", config) - - if input.Stream { - log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) //c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") - } - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - - if input.Stream { - if len(config.PromptStrings) > 1 { - return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - } - - predInput := config.PromptStrings[0] - - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - } - - responses := make(chan schema.OpenAIResponse) - - go process(predInput, input, config, ml, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { + for ev := range tokenChannel { var buf bytes.Buffer enc := json.NewEncoder(&buf) - enc.Encode(ev) + if ev.Error != nil { + log.Debug().Msgf("[CompletionEndpoint] error to debug during tokenChannel handler: %q", ev.Error) + enc.Encode(ev.Error) + } else { + enc.Encode(ev.Value) + } - log.Debug().Msgf("Sending chunk: %s", buf.String()) + log.Debug().Msgf("completion streaming sending chunk: %s", buf.String()) fmt.Fprintf(w, "data: %v\n", buf.String()) w.Flush() } resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { Index: 0, @@ -151,55 +80,15 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a })) return nil } - - var result []schema.Choice - - totalTokenUsage := backend.TokenUsage{} - - for k, i := range config.PromptStrings { - if templateFile != "" { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - } - - r, tokenUsage, err := ComputeChoices( - input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) - }, nil) - if err != nil { - return err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) + // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? + rawResponse := <-finalResultChannel + if rawResponse.Error != nil { + return rawResponse.Error } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - } - - jsonResult, _ := json.Marshal(resp) + jsonResult, _ := json.Marshal(rawResponse.Value) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(rawResponse.Value) } } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 25497095..a33050dd 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -3,92 +3,36 @@ package openai import ( "encoding/json" "fmt" - "time" - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/core/schema" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" - "github.com/google/uuid" "github.com/rs/zerolog/log" ) -func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - modelFile, input, err := readRequest(c, ml, appConfig, true) + _, request, err := fce.OpenAIRequestFromContext(c, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) + _, finalResultChannel, _, _, _, err := oais.Edit(request, false, request.Stream) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + return err } - log.Debug().Msgf("Parameter Config: %+v", config) - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model + rawResponse := <-finalResultChannel + if rawResponse.Error != nil { + return rawResponse.Error } - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - - var result []schema.Choice - totalTokenUsage := backend.TokenUsage{} - - for _, i := range config.InputStrings { - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ - Input: i, - Instruction: input.Instruction, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - } - - r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "edit", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - } - - jsonResult, _ := json.Marshal(resp) + jsonResult, _ := json.Marshal(rawResponse.Value) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(rawResponse.Value) } } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index eca34f79..be546991 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -3,14 +3,9 @@ package openai import ( "encoding/json" "fmt" - "time" "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/pkg/model" - - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -21,63 +16,25 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] -func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EmbeddingsEndpoint(fce *fiberContext.FiberContextExtractor, ebs *backend.EmbeddingsBackendService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - model, input, err := readRequest(c, ml, appConfig, true) + _, input, err := fce.OpenAIRequestFromContext(c, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + responseChannel := ebs.Embeddings(input) + + rawResponse := <-responseChannel + + if rawResponse.Error != nil { + return rawResponse.Error } - log.Debug().Msgf("Parameter Config: %+v", config) - items := []schema.Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - } - - jsonResult, _ := json.Marshal(resp) + jsonResult, _ := json.Marshal(rawResponse.Value) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(rawResponse.Value) } } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 9e806b3e..ec3d84da 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -1,50 +1,18 @@ package openai import ( - "bufio" - "encoding/base64" "encoding/json" "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" "github.com/go-skynet/LocalAI/core/backend" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -func downloadFile(url string) (string, error) { - // Get the data - resp, err := http.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // Create the file - out, err := os.CreateTemp("", "image") - if err != nil { - return "", err - } - defer out.Close() - - // Write the body to file - _, err = io.Copy(out, resp.Body) - return out.Name(), err -} - -// +// https://platform.openai.com/docs/api-reference/images/create /* * @@ -59,186 +27,36 @@ func downloadFile(url string) (string, error) { * */ + // ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create // @Summary Creates an image given a prompt. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] -func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ImageEndpoint(fce *fiberContext.FiberContextExtractor, igbs *backend.ImageGenerationBackendService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, ml, appConfig, false) + // TODO: Somewhat a hack. Is there a better place to assign this? + if igbs.BaseUrlForGeneratedImages == "" { + igbs.BaseUrlForGeneratedImages = c.BaseURL() + "/generated-images/" + } + _, request, err := fce.OpenAIRequestFromContext(c, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - if m == "" { - m = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", m) + responseChannel := igbs.GenerateImage(request) + rawResponse := <-responseChannel - config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) + if rawResponse.Error != nil { + return rawResponse.Error + } + + jsonResult, err := json.Marshal(rawResponse.Value) if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + return err } - - src := "" - if input.File != "" { - - fileData := []byte{} - // check if input.File is an URL, if so download it and save it - // to a temporary file - if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { - out, err := downloadFile(input.File) - if err != nil { - return fmt.Errorf("failed downloading file:%w", err) - } - defer os.RemoveAll(out) - - fileData, err = os.ReadFile(out) - if err != nil { - return fmt.Errorf("failed reading file:%w", err) - } - - } else { - // base 64 decode the file and write it somewhere - // that we will cleanup - fileData, err = base64.StdEncoding.DecodeString(input.File) - if err != nil { - return err - } - } - - // Create a temporary file - outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") - if err != nil { - return err - } - // write the base64 result - writer := bufio.NewWriter(outputFile) - _, err = writer.Write(fileData) - if err != nil { - outputFile.Close() - return err - } - outputFile.Close() - src = outputFile.Name() - defer os.RemoveAll(src) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - switch config.Backend { - case "stablediffusion": - config.Backend = model.StableDiffusionBackend - case "tinydream": - config.Backend = model.TinyDreamBackend - case "": - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return fmt.Errorf("invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return fmt.Errorf("invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return fmt.Errorf("invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat.Type == "b64_json" { - b64JSON = true - } - // src and clip_skip - var result []schema.Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := config.Step - if step == 0 { - step = 15 - } - - if input.Mode != 0 { - mode = input.Mode - } - - if input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = appConfig.ImageDir - } - // Create a temporary file - outputFile, err := os.CreateTemp(tempDir, "b64") - if err != nil { - return err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } - - baseURL := c.BaseURL() - - fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) - if err != nil { - return err - } - if err := fn(); err != nil { - return err - } - - item := &schema.Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base - } - - result = append(result, *item) - } - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Data: result, - } - - jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) - // Return the prediction in the response body - return c.JSON(resp) + return c.JSON(rawResponse.Value) } } diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go deleted file mode 100644 index 06e784b7..00000000 --- a/core/http/endpoints/openai/inference.go +++ /dev/null @@ -1,55 +0,0 @@ -package openai - -import ( - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - - "github.com/go-skynet/LocalAI/core/schema" - model "github.com/go-skynet/LocalAI/pkg/model" -) - -func ComputeChoices( - req *schema.OpenAIRequest, - predInput string, - config *config.BackendConfig, - o *config.ApplicationConfig, - loader *model.ModelLoader, - cb func(string, *[]schema.Choice), - tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { - n := req.N // number of completions to return - result := []schema.Choice{} - - if n == 0 { - n = 1 - } - - images := []string{} - for _, m := range req.Messages { - images = append(images, m.StringImages...) - } - - // get the model function to call for the result - predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, loader, *config, o, tokenCallback) - if err != nil { - return result, backend.TokenUsage{}, err - } - - tokenUsage := backend.TokenUsage{} - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, backend.TokenUsage{}, err - } - - tokenUsage.Prompt += prediction.Usage.Prompt - tokenUsage.Completion += prediction.Usage.Completion - - finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) - cb(finetunedResponse, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, tokenUsage, err -} diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 04e611a2..9bb2b2ca 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -1,61 +1,21 @@ package openai import ( - "regexp" - - "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/core/services" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := ml.ListModels() - if err != nil { - return err - } - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []schema.OpenAIModel{} - - var filterFn func(name string) bool + // If blank, no filter is applied. filter := c.Query("filter") - - // If filter is not specified, do not filter the list by model name - if filter == "" { - filterFn = func(_ string) bool { return true } - } else { - // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn - rxp, err := regexp.Compile(filter) - if err != nil { - return err - } - filterFn = func(name string) bool { - return rxp.MatchString(name) - } - } - // By default, exclude any loose files that are already referenced by a configuration file. excludeConfigured := c.QueryBool("excludeConfigured", true) - // Start with the known configurations - for _, c := range cl.GetAllBackendConfigs() { - if excludeConfigured { - mm[c.Model] = nil - } - - if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) - } - } - - // Then iterate through the loose files: - for _, m := range models { - // And only adds them if they shouldn't be skipped. - if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) - } + dataModels, err := lms.ListModels(filter, excludeConfigured) + if err != nil { + return err } return c.JSON(struct { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go deleted file mode 100644 index 369fb0b8..00000000 --- a/core/http/endpoints/openai/request.go +++ /dev/null @@ -1,285 +0,0 @@ -package openai - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - - "github.com/go-skynet/LocalAI/core/config" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - received, _ := json.Marshal(input) - - ctx, cancel := context.WithCancel(o.Context) - input.Context = ctx - input.Cancel = cancel - - log.Debug().Msgf("Request received: %s", string(received)) - - modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel) - - return modelFile, input, err -} - -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string -func getBase64Image(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := http.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -} - -func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.ModelBaseName != "" { - config.AutoGPTQ.ModelBaseName = input.ModelBaseName - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.UseFastTokenizer { - config.UseFastTokenizer = input.UseFastTokenizer - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - var toolChoice grammar.Tool - - switch content := input.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - input.FunctionCall = map[string]interface{}{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - index := 0 - for i, m := range input.Messages { - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - for _, pp := range c { - if pp.Type == "text" { - input.Messages[i].StringContent = pp.Text - } else if pp.Type == "image_url" { - // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := getBase64Image(pp.ImageURL.URL) - if err == nil { - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent - index++ - } else { - fmt.Print("Failed encoding image", err) - } - } - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} - -func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, - config.LoadOptionDebug(debug), - config.LoadOptionThreads(threads), - config.LoadOptionContextSize(ctx), - config.LoadOptionF16(f16), - ) - - // Set the parameters for the language model prediction - updateRequestConfig(cfg, input) - - return cfg, input, err -} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index c7dd39e7..572cec12 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -9,8 +9,7 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - model "github.com/go-skynet/LocalAI/pkg/model" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -23,17 +22,15 @@ import ( // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] -func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.TranscriptionBackendService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, ml, appConfig, false) + _, request, err := fce.OpenAIRequestFromContext(c, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } + // TODO: Investigate this file copy stuff later - potentially belongs in service. + // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { @@ -65,13 +62,16 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) - if err != nil { - return err - } + request.File = dst - log.Debug().Msgf("Trascribed: %+v", tr) + responseChannel := tbs.Transcribe(request) + rawResponse := <-responseChannel + + if rawResponse.Error != nil { + return rawResponse.Error + } + log.Debug().Msgf("Transcribed: %+v", rawResponse.Value) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) + return c.Status(http.StatusOK).JSON(rawResponse.Value) } } diff --git a/core/schema/whisper.go b/core/schema/transcription.go similarity index 90% rename from core/schema/whisper.go rename to core/schema/transcription.go index 41413c1f..fe1799fa 100644 --- a/core/schema/whisper.go +++ b/core/schema/transcription.go @@ -10,7 +10,7 @@ type Segment struct { Tokens []int `json:"tokens"` } -type Result struct { +type TranscriptionResult struct { Segments []Segment `json:"segments"` Text string `json:"text"` } diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index 979a67a3..a610432c 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -15,22 +15,22 @@ import ( gopsutil "github.com/shirou/gopsutil/v3/process" ) -type BackendMonitor struct { +type BackendMonitorService struct { configLoader *config.BackendConfigLoader modelLoader *model.ModelLoader options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. } -func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { - return BackendMonitor{ +func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { + return &BackendMonitorService{ configLoader: configLoader, modelLoader: modelLoader, options: appConfig, } } -func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bm.configLoader.GetBackendConfig(modelName) +func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) { + config, exists := bms.configLoader.GetBackendConfig(modelName) var backendId string if exists { backendId = config.Model @@ -46,8 +46,8 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string return backendId, nil } -func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bm.configLoader.GetBackendConfig(model) +func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { + config, exists := bms.configLoader.GetBackendConfig(model) var backend string if exists { backend = config.Model @@ -60,7 +60,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe backend = fmt.Sprintf("%s.bin", backend) } - pid, err := bm.modelLoader.GetGRPCPID(backend) + pid, err := bms.modelLoader.GetGRPCPID(backend) if err != nil { log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid") @@ -101,12 +101,12 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe }, nil } -func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) +func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { + backendId, err := bms.getModelLoaderIDFromModelName(modelName) if err != nil { return nil, err } - modelAddr := bm.modelLoader.CheckIsLoaded(backendId) + modelAddr := bms.modelLoader.CheckIsLoaded(backendId) if modelAddr == "" { return nil, fmt.Errorf("backend %s is not currently loaded", backendId) } @@ -114,7 +114,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bm.SampleLocalBackendProcess(backendId) + val, slbErr := bms.SampleLocalBackendProcess(backendId) if slbErr != nil { return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) } @@ -131,10 +131,10 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse return status, nil } -func (bm BackendMonitor) ShutdownModel(modelName string) error { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) +func (bms BackendMonitorService) ShutdownModel(modelName string) error { + backendId, err := bms.getModelLoaderIDFromModelName(modelName) if err != nil { return err } - return bm.modelLoader.ShutdownModel(backendId) + return bms.modelLoader.ShutdownModel(backendId) } diff --git a/core/services/gallery.go b/core/services/gallery.go index b068abbb..1ef8e3e2 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -3,14 +3,18 @@ package services import ( "context" "encoding/json" + "errors" "os" + "path/filepath" "strings" "sync" "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/embedded" + "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -29,18 +33,6 @@ func NewGalleryService(modelPath string) *GalleryService { } } -func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetGalleryConfigFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) - - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) -} - func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) { g.Lock() defer g.Unlock() @@ -92,10 +84,10 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) } } else if op.ConfigURL != "" { - startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) + PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) err = cl.Preload(g.modelPath) } else { - err = prepareModel(g.modelPath, op.Req, cl, progressCallback) + err = prepareModel(g.modelPath, op.Req, progressCallback) } if err != nil { @@ -127,13 +119,12 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { +func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) - + err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } else { if strings.Contains(r.ID, "@") { err = gallery.InstallModelFromGallery( @@ -158,7 +149,7 @@ func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, g return err } - return processRequests(modelPath, s, cl, galleries, requests) + return processRequests(modelPath, galleries, requests) } func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error { @@ -168,5 +159,90 @@ func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, return err } - return processRequests(modelPath, s, cl, galleries, requests) + return processRequests(modelPath, galleries, requests) +} + +// PreloadModelsConfigurations will preload models from the given list of URLs +// It will download the model if it is not already present in the model path +// It will also try to resolve if the model is an embedded model YAML configuration +func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { + for _, url := range models { + + // As a best effort, try to resolve the model from the remote library + // if it's not resolved we try with the other method below + if modelLibraryURL != "" { + lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) + if err == nil { + if lib[url] != "" { + log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) + url = lib[url] + } + } + } + + url = embedded.ModelShortURL(url) + switch { + case embedded.ExistsInModelsLibrary(url): + modelYAML, err := embedded.ResolveContent(url) + // If we resolve something, just save it to disk and continue + if err != nil { + log.Error().Err(err).Msg("error resolving model content") + continue + } + + log.Debug().Msgf("[startup] resolved embedded model: %s", url) + md5Name := utils.MD5(url) + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") + } + case downloader.LooksLikeURL(url): + log.Debug().Msgf("[startup] resolved model to download: %s", url) + + // md5 of model name + md5Name := utils.MD5(url) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + }) + if err != nil { + log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + } + } + default: + if _, err := os.Stat(url); err == nil { + log.Debug().Msgf("[startup] resolved local model: %s", url) + // copy to modelPath + md5Name := utils.MD5(url) + + modelYAML, err := os.ReadFile(url) + if err != nil { + log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") + continue + } + + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") + } + } else { + log.Warn().Msgf("[startup] failed resolving model '%s'", url) + } + } + } +} + +func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetGalleryConfigFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) } diff --git a/core/services/list_models.go b/core/services/list_models.go new file mode 100644 index 00000000..a21e6faf --- /dev/null +++ b/core/services/list_models.go @@ -0,0 +1,72 @@ +package services + +import ( + "regexp" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/model" +) + +type ListModelsService struct { + bcl *config.BackendConfigLoader + ml *model.ModelLoader + appConfig *config.ApplicationConfig +} + +func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService { + return &ListModelsService{ + bcl: bcl, + ml: ml, + appConfig: appConfig, + } +} + +func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) { + + models, err := lms.ml.ListModels() + if err != nil { + return nil, err + } + + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []schema.OpenAIModel{} + + var filterFn func(name string) bool + + // If filter is not specified, do not filter the list by model name + if filter == "" { + filterFn = func(_ string) bool { return true } + } else { + // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn + rxp, err := regexp.Compile(filter) + if err != nil { + return nil, err + } + filterFn = func(name string) bool { + return rxp.MatchString(name) + } + } + + // Start with the known configurations + for _, c := range lms.bcl.GetAllBackendConfigs() { + if excludeConfigured { + mm[c.Model] = nil + } + + if filterFn(c.Name) { + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + } + } + + // Then iterate through the loose files: + for _, m := range models { + // And only adds them if they shouldn't be skipped. + if _, exists := mm[m]; !exists && filterFn(m) { + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + } + } + + return dataModels, nil +} diff --git a/pkg/startup/model_preload_test.go b/core/services/model_preload_test.go similarity index 96% rename from pkg/startup/model_preload_test.go rename to core/services/model_preload_test.go index 63a8f8b0..fc65d565 100644 --- a/pkg/startup/model_preload_test.go +++ b/core/services/model_preload_test.go @@ -1,13 +1,14 @@ -package startup_test +package services_test import ( "fmt" "os" "path/filepath" - . "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" + . "github.com/go-skynet/LocalAI/core/services" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/core/services/openai.go b/core/services/openai.go new file mode 100644 index 00000000..7a2679ad --- /dev/null +++ b/core/services/openai.go @@ -0,0 +1,808 @@ +package services + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/concurrency" + "github.com/go-skynet/LocalAI/pkg/grammar" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/google/uuid" + "github.com/imdario/mergo" + "github.com/rs/zerolog/log" +) + +type endpointGenerationConfigurationFn func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration + +type endpointConfiguration struct { + SchemaObject string + TemplatePath string + TemplateData model.PromptTemplateData + ResultMappingFn func(resp *backend.LLMResponse, index int) schema.Choice + CompletionMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] + TokenMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] +} + +// TODO: This is used for completion and edit. I am pretty sure I forgot parts, but fix it later. +func simpleMapper(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { + if resp.Error != nil || resp.Value == nil { + return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} + } + return concurrency.ErrorOr[*schema.OpenAIResponse]{ + Value: &schema.OpenAIResponse{ + Choices: []schema.Choice{ + { + Text: resp.Value.Response, + }, + }, + Usage: schema.OpenAIUsage{ + PromptTokens: resp.Value.Usage.Prompt, + CompletionTokens: resp.Value.Usage.Completion, + TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, + }, + }, + } +} + +// TODO: Consider alternative names for this. +// The purpose of this struct is to hold a reference to the OpenAI request context information +// This keeps things simple within core/services/openai.go and allows consumers to "see" this information if they need it +type OpenAIRequestTraceID struct { + ID string + Created int +} + +// This type split out from core/backend/llm.go - I'm still not _totally_ sure about this, but it seems to make sense to keep the generic LLM code from the OpenAI specific higher level functionality +type OpenAIService struct { + bcl *config.BackendConfigLoader + ml *model.ModelLoader + appConfig *config.ApplicationConfig + llmbs *backend.LLMBackendService +} + +func NewOpenAIService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, llmbs *backend.LLMBackendService) *OpenAIService { + return &OpenAIService{ + bcl: bcl, + ml: ml, + appConfig: appConfig, + llmbs: llmbs, + } +} + +// Keeping in place as a reminder to POTENTIALLY ADD MORE VALIDATION HERE??? +func (oais *OpenAIService) getConfig(request *schema.OpenAIRequest) (*config.BackendConfig, *schema.OpenAIRequest, error) { + return oais.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, oais.appConfig) +} + +// TODO: It would be a lot less messy to make a return struct that had references to each of these channels +// INTENTIONALLY not doing that quite yet - I believe we need to let the references to unused channels die for the GC to automatically collect -- can we manually free()? +// finalResultsChannel is the primary async return path: one result for the entire request. +// promptResultsChannels is DUBIOUS. It's expected to be raw fan-out used within the function itself, but I am exposing for testing? One bundle of LLMResponseBundle per PromptString? Gets all N completions for a single prompt. +// completionsChannel is a channel that emits one *LLMResponse per generated completion, be that different prompts or N. Seems the most useful other than "entire request" Request is available to attempt tracing??? +// tokensChannel is a channel that emits one *LLMResponse per generated token. Let's see what happens! +func (oais *OpenAIService) Completion(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( + traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], + completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { + + return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { + return endpointConfiguration{ + SchemaObject: "text_completion", + TemplatePath: bc.TemplateConfig.Completion, + TemplateData: model.PromptTemplateData{ + SystemPrompt: bc.SystemPrompt, + }, + ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { + return schema.Choice{ + Index: promptIndex, + FinishReason: "stop", + Text: resp.Response, + } + }, + CompletionMappingFn: simpleMapper, + TokenMappingFn: simpleMapper, + } + }, notifyOnPromptResult, notifyOnToken, nil) +} + +func (oais *OpenAIService) Edit(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( + traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], + completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { + + return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { + + return endpointConfiguration{ + SchemaObject: "edit", + TemplatePath: bc.TemplateConfig.Edit, + TemplateData: model.PromptTemplateData{ + SystemPrompt: bc.SystemPrompt, + Instruction: request.Instruction, + }, + ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { + return schema.Choice{ + Index: promptIndex, + FinishReason: "stop", + Text: resp.Response, + } + }, + CompletionMappingFn: simpleMapper, + TokenMappingFn: simpleMapper, + } + }, notifyOnPromptResult, notifyOnToken, nil) +} + +func (oais *OpenAIService) Chat(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( + traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], + completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { + + return oais.GenerateFromMultipleMessagesChatRequest(request, notifyOnPromptResult, notifyOnToken, nil) +} + +func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest, endpointConfigFn endpointGenerationConfigurationFn, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( + traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], + completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { + + if initialTraceID == nil { + traceID = &OpenAIRequestTraceID{ + ID: uuid.New().String(), + Created: int(time.Now().Unix()), + } + } else { + traceID = initialTraceID + } + + bc, request, err := oais.getConfig(request) + if err != nil { + log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration") + return + } + + if request.ResponseFormat.Type == "json_object" { + request.Grammar = grammar.JSONBNF + } + + bc.Grammar = request.Grammar + + if request.Stream && len(bc.PromptStrings) > 1 { + log.Warn().Msg("potentially cannot handle more than 1 `PromptStrings` when Streaming?") + } + + rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + finalResultChannel = rawFinalResultChannel + promptResultsChannels = []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle]{} + var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] + var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] + if notifyOnPromptResult { + rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + } + if notifyOnToken { + rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + } + + promptResultsChannelLock := sync.Mutex{} + + endpointConfig := endpointConfigFn(bc, request) + + if len(endpointConfig.TemplatePath) == 0 { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { + endpointConfig.TemplatePath = bc.Model + } else { + log.Warn().Msgf("failed to find any template for %+v", request) + } + } + + setupWG := sync.WaitGroup{} + var prompts []string + if lPS := len(bc.PromptStrings); lPS > 0 { + setupWG.Add(lPS) + prompts = bc.PromptStrings + } else { + setupWG.Add(len(bc.InputStrings)) + prompts = bc.InputStrings + } + + var setupError error = nil + + for pI, p := range prompts { + + go func(promptIndex int, prompt string) { + if endpointConfig.TemplatePath != "" { + promptTemplateData := model.PromptTemplateData{ + Input: prompt, + } + err := mergo.Merge(promptTemplateData, endpointConfig.TemplateData, mergo.WithOverride) + if err == nil { + templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, endpointConfig.TemplatePath, promptTemplateData) + if err == nil { + prompt = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", prompt) + } + } + } + + log.Debug().Msgf("[OAIS GenerateTextFromRequest] Prompt: %q", prompt) + promptResultsChannel, completionChannels, tokenChannels, err := oais.llmbs.GenerateText(prompt, request, bc, + func(r *backend.LLMResponse) schema.Choice { + return endpointConfig.ResultMappingFn(r, promptIndex) + }, notifyOnPromptResult, notifyOnToken) + if err != nil { + log.Error().Msgf("Unable to generate text prompt: %q\nerr: %q", prompt, err) + promptResultsChannelLock.Lock() + setupError = errors.Join(setupError, err) + promptResultsChannelLock.Unlock() + setupWG.Done() + return + } + if notifyOnPromptResult { + concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(completionChannels, endpointConfig.CompletionMappingFn), rawCompletionsChannel, true) + } + if notifyOnToken { + concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, endpointConfig.TokenMappingFn), rawTokenChannel, true) + } + promptResultsChannelLock.Lock() + promptResultsChannels = append(promptResultsChannels, promptResultsChannel) + promptResultsChannelLock.Unlock() + setupWG.Done() + }(pI, p) + + } + setupWG.Wait() + + // If any of the setup goroutines experienced an error, quit early here. + if setupError != nil { + go func() { + log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup") + rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError} + close(rawFinalResultChannel) + }() + return + } + + initialResponse := &schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, + Object: endpointConfig.SchemaObject, + Usage: schema.OpenAIUsage{}, + } + + // utils.SliceOfChannelsRawMerger[[]schema.Choice](promptResultsChannels, rawFinalResultChannel, func(results []schema.Choice) (*schema.OpenAIResponse, error) { + concurrency.SliceOfChannelsReducer( + promptResultsChannels, rawFinalResultChannel, + func(iv concurrency.ErrorOr[*backend.LLMResponseBundle], result concurrency.ErrorOr[*schema.OpenAIResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { + + if iv.Error != nil { + result.Error = iv.Error + return result + } + result.Value.Usage.PromptTokens += iv.Value.Usage.Prompt + result.Value.Usage.CompletionTokens += iv.Value.Usage.Completion + result.Value.Usage.TotalTokens = result.Value.Usage.PromptTokens + result.Value.Usage.CompletionTokens + + result.Value.Choices = append(result.Value.Choices, iv.Value.Response...) + + return result + }, concurrency.ErrorOr[*schema.OpenAIResponse]{Value: initialResponse}, true) + + completionsChannel = rawCompletionsChannel + tokenChannel = rawTokenChannel + + return +} + +// TODO: For porting sanity, this is distinct from GenerateTextFromRequest and is _currently_ specific to Chat purposes +// this is not a final decision -- just a reality of moving a lot of parts at once +// / This has _become_ Chat which wasn't the goal... More cleanup in the future once it's stable? +func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( + traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], + completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { + + if initialTraceID == nil { + traceID = &OpenAIRequestTraceID{ + ID: uuid.New().String(), + Created: int(time.Now().Unix()), + } + } else { + traceID = initialTraceID + } + + bc, request, err := oais.getConfig(request) + if err != nil { + return + } + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if bc.FunctionsConfig.NoActionFunctionName != "" { + noActionName = bc.FunctionsConfig.NoActionFunctionName + } + if bc.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = bc.FunctionsConfig.NoActionDescriptionName + } + + if request.ResponseFormat.Type == "json_object" { + request.Grammar = grammar.JSONBNF + } + + bc.Grammar = request.Grammar + + processFunctions := false + funcs := grammar.Functions{} + // process functions if we have any defined or if we have a function call string + if len(request.Functions) > 0 && bc.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, request.Functions...) + if !bc.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if bc.FunctionToCall() != "" { + funcs = funcs.Select(bc.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + bc.Grammar = jsStruct.Grammar("", bc.FunctionsConfig.ParallelCalls) + } else if request.JSONFunctionGrammarObject != nil { + bc.Grammar = request.JSONFunctionGrammarObject.Grammar("", bc.FunctionsConfig.ParallelCalls) + } + + if request.Stream && processFunctions { + log.Warn().Msg("Streaming + Functions is highly experimental in this version") + } + + var predInput string + + if !bc.TemplateConfig.UseTokenizerTemplate || processFunctions { + + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range request.Messages { + var content string + role := i.Role + + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := bc.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := bc.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + + fcall := i.FunctionCall + if len(i.ToolCalls) > 0 { + fcall = i.ToolCalls + } + + // First attempt to populate content via a chat message specific template + if bc.TemplateConfig.ChatMessage != "" { + chatMessageData := model.ChatMessageTemplateData{ + SystemPrompt: bc.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + FunctionCall: fcall, + FunctionName: i.Name, + LastMessage: messageIndex == (len(request.Messages) - 1), + Function: bc.Grammar != "" && (messageIndex == (len(request.Messages) - 1)), + MessageIndex: messageIndex, + } + templatedChatMessage, err := oais.ml.EvaluateTemplateForChatMessage(bc.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, bc.TemplateConfig.ChatMessage, err) + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", bc.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + marshalAnyRole := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + marshalAny := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + + if i.FunctionCall != nil { + marshalAnyRole(i.FunctionCall) + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + + if i.FunctionCall != nil { + marshalAny(i.FunctionCall) + } + + if i.ToolCalls != nil { + marshalAny(i.ToolCalls) + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + templateFile := "" + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { + templateFile = bc.Model + } + + if bc.TemplateConfig.Chat != "" && !processFunctions { + templateFile = bc.TemplateConfig.Chat + } + + if bc.TemplateConfig.Functions != "" && processFunctions { + templateFile = bc.TemplateConfig.Functions + } + + if templateFile != "" { + templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: bc.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + } + } + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", bc.Grammar) + } + + rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] + var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] + if notifyOnPromptResult { + rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + } + if notifyOnToken { + rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) + } + + rawResultChannel, individualCompletionChannels, tokenChannels, err := oais.llmbs.GenerateText(predInput, request, bc, func(resp *backend.LLMResponse) schema.Choice { + return schema.Choice{ + Index: 0, // ??? + FinishReason: "stop", + Message: &schema.Message{ + Role: "assistant", + Content: resp.Response, + }, + } + }, notifyOnPromptResult, notifyOnToken) + + chatSimpleMappingFn := func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { + if resp.Error != nil || resp.Value == nil { + return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} + } + return concurrency.ErrorOr[*schema.OpenAIResponse]{ + Value: &schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + Delta: &schema.Message{ + Role: "assistant", + Content: resp.Value.Response, + }, + Index: 0, + }, + }, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: resp.Value.Usage.Prompt, + CompletionTokens: resp.Value.Usage.Completion, + TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, + }, + }, + } + } + + if notifyOnPromptResult { + concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(individualCompletionChannels, chatSimpleMappingFn), rawCompletionsChannel, true) + } + if notifyOnToken { + concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, chatSimpleMappingFn), rawTokenChannel, true) + } + + go func() { + rawResult := <-rawResultChannel + if rawResult.Error != nil { + log.Warn().Msgf("OpenAIService::processTools GenerateText error [DEBUG THIS?] %q", rawResult.Error) + return + } + llmResponseChoices := rawResult.Value.Response + + if processFunctions && len(llmResponseChoices) > 1 { + log.Warn().Msgf("chat functions response with %d choices in response, debug this?", len(llmResponseChoices)) + log.Debug().Msgf("%+v", llmResponseChoices) + } + + for _, result := range rawResult.Value.Response { + // If no functions, just return the raw result. + if !processFunctions { + + resp := schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{result}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: rawResult.Value.Usage.Prompt, + CompletionTokens: rawResult.Value.Usage.Completion, + TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, + }, + } + + rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} + + continue + } + // At this point, things are function specific! + + // Oh no this can't be the right way to do this... but it works. Save us, mudler! + fString := fmt.Sprintf("%s", result.Message.Content) + results := parseFunctionCall(fString, bc.FunctionsConfig.ParallelCalls) + noActionToRun := (len(results) > 0 && results[0].name == noActionName) + + if noActionToRun { + log.Debug().Msg("-- noActionToRun branch --") + initialMessage := schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: ""}}}, + Object: "stop", + } + rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} + + result, err := oais.handleQuestion(bc, request, results[0].arguments, predInput) + if err != nil { + log.Error().Msgf("error handling question: %s", err.Error()) + return + } + + resp := schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: rawResult.Value.Usage.Prompt, + CompletionTokens: rawResult.Value.Usage.Completion, + TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, + }, + } + + rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} + + } else { + log.Debug().Msgf("[GenerateFromMultipleMessagesChatRequest] fnResultsBranch: %+v", results) + for i, ss := range results { + name, args := ss.name, ss.arguments + + initialMessage := schema.OpenAIResponse{ + ID: traceID.ID, + Created: traceID.Created, + Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: traceID.ID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} + } + } + } + + close(rawFinalResultChannel) + }() + + finalResultChannel = rawFinalResultChannel + completionsChannel = rawCompletionsChannel + tokenChannel = rawTokenChannel + return +} + +func (oais *OpenAIService) handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, args, prompt string) (string, error) { + log.Debug().Msgf("[handleQuestion called] nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(args), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = oais.llmbs.Finetune(*config, prompt, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + return message, nil + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU/GPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + + resultChannel, _, err := oais.llmbs.Inference(input.Context, &backend.LLMRequest{ + Text: prompt, + Images: images, + RawMessages: input.Messages, // Experimental + }, config, false) + + if err != nil { + log.Error().Msgf("inference setup error: %s", err.Error()) + return "", err + } + + raw := <-resultChannel + if raw.Error != nil { + log.Error().Msgf("inference error: %q", raw.Error.Error()) + return "", err + } + if raw.Value == nil { + log.Warn().Msgf("nil inference response") + return "", nil + } + return oais.llmbs.Finetune(*config, prompt, raw.Value.Response), nil +} + +type funcCallResults struct { + name string + arguments string +} + +func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { + + results := []funcCallResults{} + + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + + for _, s := range ss { + func_name, ok := s["function"] + if !ok { + continue + } + args, ok := s["arguments"] + if !ok { + continue + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + continue + } + results = append(results, funcCallResults{name: funcName, arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + if err := json.Unmarshal([]byte(s), &ss); err != nil { + log.Error().Msgf("error unmarshalling JSON: %s", err.Error()) + return results + } + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name, ok := ss["function"] + if !ok { + log.Debug().Msgf("ss[function] is not OK!, llm result: %q", llmresult) + return results + } + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + if !ok { + log.Debug().Msg("ss[arguments] is not OK!") + return results + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + log.Debug().Msgf("unexpected func_name: %+v", func_name) + return results + } + results = append(results, funcCallResults{name: funcName, arguments: string(d)}) + } + return results +} diff --git a/core/startup/startup.go b/core/startup/startup.go index 6298f034..92ccaa9d 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -4,17 +4,21 @@ import ( "fmt" "os" + "github.com/go-skynet/LocalAI/core" + "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" + openaiendpoint "github.com/go-skynet/LocalAI/core/http/endpoints/openai" // TODO: This is dubious. Fix this when splitting assistant api up. "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" "github.com/go-skynet/LocalAI/pkg/model" - pkgStartup "github.com/go-skynet/LocalAI/pkg/startup" + "github.com/go-skynet/LocalAI/pkg/utils" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { +// (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { +func Startup(opts ...config.AppOption) (*core.Application, error) { options := config.NewApplicationConfig(opts...) zerolog.SetGlobalLevel(zerolog.InfoLevel) @@ -27,68 +31,75 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode // Make sure directories exists if options.ModelPath == "" { - return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty") + return nil, fmt.Errorf("options.ModelPath cannot be empty") } err := os.MkdirAll(options.ModelPath, 0755) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err) + return nil, fmt.Errorf("unable to create ModelPath: %q", err) } if options.ImageDir != "" { err := os.MkdirAll(options.ImageDir, 0755) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err) + return nil, fmt.Errorf("unable to create ImageDir: %q", err) } } if options.AudioDir != "" { err := os.MkdirAll(options.AudioDir, 0755) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err) + return nil, fmt.Errorf("unable to create AudioDir: %q", err) } } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0755) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err) + return nil, fmt.Errorf("unable to create UploadDir: %q", err) + } + } + if options.ConfigsDir != "" { + err := os.MkdirAll(options.ConfigsDir, 0755) + if err != nil { + return nil, fmt.Errorf("unable to create ConfigsDir: %q", err) } } - // - pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) + // Load config jsons + utils.LoadConfig(options.UploadDir, openaiendpoint.UploadedFilesFile, &openaiendpoint.UploadedFiles) + utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsConfigFile, &openaiendpoint.Assistants) + utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsFileConfigFile, &openaiendpoint.AssistantFiles) - cl := config.NewBackendConfigLoader() - ml := model.NewModelLoader(options.ModelPath) + app := createApplication(options) - configLoaderOpts := options.ToConfigLoaderOptions() + services.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) - if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { + if err := app.BackendConfigLoader.LoadBackendConfigsFromPath(options.ModelPath, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { log.Error().Err(err).Msg("error loading config files") } if options.ConfigFile != "" { - if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil { + if err := app.BackendConfigLoader.LoadBackendConfigFile(options.ConfigFile, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { log.Error().Err(err).Msg("error loading config file") } } - if err := cl.Preload(options.ModelPath); err != nil { + if err := app.BackendConfigLoader.Preload(options.ModelPath); err != nil { log.Error().Err(err).Msg("error downloading models") } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { - return nil, nil, nil, err + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, app.BackendConfigLoader, options.Galleries); err != nil { + return nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { - return nil, nil, nil, err + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, app.BackendConfigLoader, options.Galleries); err != nil { + return nil, err } } if options.Debug { - for _, v := range cl.ListBackendConfigs() { - cfg, _ := cl.GetBackendConfig(v) + for _, v := range app.BackendConfigLoader.ListBackendConfigs() { + cfg, _ := app.BackendConfigLoader.GetBackendConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } @@ -106,17 +117,17 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") - ml.StopAllGRPC() + app.ModelLoader.StopAllGRPC() }() if options.WatchDog { wd := model.NewWatchDog( - ml, + app.ModelLoader, options.WatchDogBusyTimeout, options.WatchDogIdleTimeout, options.WatchDogBusy, options.WatchDogIdle) - ml.SetWatchDog(wd) + app.ModelLoader.SetWatchDog(wd) go wd.Run() go func() { <-options.Context.Done() @@ -126,5 +137,35 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } log.Info().Msg("core/startup process completed!") - return cl, ml, options, nil + return app, nil +} + +// In Lieu of a proper DI framework, this function wires up the Application manually. +// This is in core/startup rather than core/state.go to keep package references clean! +func createApplication(appConfig *config.ApplicationConfig) *core.Application { + app := &core.Application{ + ApplicationConfig: appConfig, + BackendConfigLoader: config.NewBackendConfigLoader(), + ModelLoader: model.NewModelLoader(appConfig.ModelPath), + } + + var err error + + app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + + app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath) + app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) + + app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() + if err != nil { + log.Warn().Msg("Unable to initialize LocalAIMetricsService - non-fatal, optional service") + } + + return app } diff --git a/core/state.go b/core/state.go new file mode 100644 index 00000000..cf0d614b --- /dev/null +++ b/core/state.go @@ -0,0 +1,41 @@ +package core + +import ( + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/model" +) + +// TODO: Can I come up with a better name or location for this? +// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy +// Perhaps a proper DI system is worth it in the future, but for now keep things simple. +type Application struct { + + // Application-Level Config + ApplicationConfig *config.ApplicationConfig + // ApplicationState *ApplicationState + + // Core Low-Level Services + BackendConfigLoader *config.BackendConfigLoader + ModelLoader *model.ModelLoader + + // Backend Services + EmbeddingsBackendService *backend.EmbeddingsBackendService + ImageGenerationBackendService *backend.ImageGenerationBackendService + LLMBackendService *backend.LLMBackendService + TranscriptionBackendService *backend.TranscriptionBackendService + TextToSpeechBackendService *backend.TextToSpeechBackendService + + // LocalAI System Services + BackendMonitorService *services.BackendMonitorService + GalleryService *services.GalleryService + ListModelsService *services.ListModelsService + LocalAIMetricsService *services.LocalAIMetricsService + OpenAIService *services.OpenAIService +} + +// TODO [NEXT PR?]: Break up ApplicationConfig. +// Migrate over stuff that is not set via config at all - especially runtime stuff +type ApplicationState struct { +} diff --git a/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru b/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru new file mode 100644 index 00000000..c33bafe1 --- /dev/null +++ b/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru @@ -0,0 +1,25 @@ +meta { + name: -completions Stream + type: http + seq: 4 +} + +post { + url: {{PROTOCOL}}{{HOST}}:{{PORT}}/completions + body: json + auth: none +} + +headers { + Content-Type: application/json +} + +body:json { + { + "model": "{{DEFAULT_MODEL}}", + "prompt": "function downloadFile(string url, string outputPath) {", + "max_tokens": 256, + "temperature": 0.5, + "stream": true + } +} diff --git a/pkg/concurrency/concurrency.go b/pkg/concurrency/concurrency.go new file mode 100644 index 00000000..324e8cc5 --- /dev/null +++ b/pkg/concurrency/concurrency.go @@ -0,0 +1,135 @@ +package concurrency + +import ( + "sync" +) + +// TODO: closeWhenDone bool parameter :: +// It currently is experimental, and therefore exists. +// Is there ever a situation to use false? + +// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of a second type. +// mappingFn allows the caller to convert from the input type to the output type +// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. +// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. +func SliceOfChannelsRawMerger[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan IndividualResultType, outputChannel chan<- OutputResultType, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { + var wg sync.WaitGroup + wg.Add(len(individualResultChannels)) + mergingFn := func(c <-chan IndividualResultType) { + for r := range c { + mr, err := mappingFn(r) + if err == nil { + outputChannel <- mr + } + } + wg.Done() + } + for _, irc := range individualResultChannels { + go mergingFn(irc) + } + if closeWhenDone { + go func() { + wg.Wait() + close(outputChannel) + }() + } + + return &wg +} + +// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of THE SAME TYPE. +// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. +// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. +func SliceOfChannelsRawMergerWithoutMapping[ResultType any](individualResultsChannels []<-chan ResultType, outputChannel chan<- ResultType, closeWhenDone bool) *sync.WaitGroup { + return SliceOfChannelsRawMerger(individualResultsChannels, outputChannel, func(v ResultType) (ResultType, error) { return v, nil }, closeWhenDone) +} + +// This function is used to merge the results of a slice of channels of a specific result type down to a single succcess result channel of a second type, and an error channel +// mappingFn allows the caller to convert from the input type to the output type +// This variant is designed to be aware of concurrency.ErrorOr[T], splitting successes from failures. +// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. +// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. +func SliceOfChannelsMergerWithErrors[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan ErrorOr[IndividualResultType], successChannel chan<- OutputResultType, errorChannel chan<- error, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { + var wg sync.WaitGroup + wg.Add(len(individualResultChannels)) + mergingFn := func(c <-chan ErrorOr[IndividualResultType]) { + for r := range c { + if r.Error != nil { + errorChannel <- r.Error + } else { + mv, err := mappingFn(r.Value) + if err != nil { + errorChannel <- err + } else { + successChannel <- mv + } + } + } + wg.Done() + } + for _, irc := range individualResultChannels { + go mergingFn(irc) + } + if closeWhenDone { + go func() { + wg.Wait() + close(successChannel) + close(errorChannel) + }() + } + return &wg +} + +// This function is used to reduce down the results of a slice of channels of a specific result type down to a single result value of a second type. +// reducerFn allows the caller to convert from the input type to the output type +// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. +// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. +func SliceOfChannelsReducer[InputResultType any, OutputResultType any](individualResultsChannels []<-chan InputResultType, outputChannel chan<- OutputResultType, + reducerFn func(iv InputResultType, ov OutputResultType) OutputResultType, initialValue OutputResultType, closeWhenDone bool) (wg *sync.WaitGroup) { + wg = &sync.WaitGroup{} + wg.Add(len(individualResultsChannels)) + reduceLock := sync.Mutex{} + reducingFn := func(c <-chan InputResultType) { + for iv := range c { + reduceLock.Lock() + initialValue = reducerFn(iv, initialValue) + reduceLock.Unlock() + } + wg.Done() + } + for _, irc := range individualResultsChannels { + go reducingFn(irc) + } + go func() { + wg.Wait() + outputChannel <- initialValue + if closeWhenDone { + close(outputChannel) + } + }() + return wg +} + +// This function is primarily designed to be used in combination with the above utility functions. +// A slice of input result channels of a specific type is provided, along with a function to map those values to another type +// A slice of output result channels is returned, where each value is mapped as it comes in. +// The order of the slice will be retained. +func SliceOfChannelsTransformer[InputResultType any, OutputResultType any](inputChanels []<-chan InputResultType, mappingFn func(v InputResultType) OutputResultType) (outputChannels []<-chan OutputResultType) { + rawOutputChannels := make([]<-chan OutputResultType, len(inputChanels)) + + transformingFn := func(ic <-chan InputResultType, oc chan OutputResultType) { + for iv := range ic { + oc <- mappingFn(iv) + } + close(oc) + } + + for ci, c := range inputChanels { + roc := make(chan OutputResultType) + go transformingFn(c, roc) + rawOutputChannels[ci] = roc + } + + outputChannels = rawOutputChannels + return +} diff --git a/pkg/concurrency/concurrency_test.go b/pkg/concurrency/concurrency_test.go new file mode 100644 index 00000000..fedd74be --- /dev/null +++ b/pkg/concurrency/concurrency_test.go @@ -0,0 +1,101 @@ +package concurrency_test + +// TODO: noramlly, these go in utils_tests, right? Why does this cause problems only in pkg/utils? + +import ( + "fmt" + "slices" + + . "github.com/go-skynet/LocalAI/pkg/concurrency" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("utils/concurrency tests", func() { + It("SliceOfChannelsReducer works", func() { + individualResultsChannels := []<-chan int{} + initialValue := 0 + for i := 0; i < 3; i++ { + c := make(chan int) + go func(i int, c chan int) { + for ii := 1; ii < 4; ii++ { + c <- (i * ii) + } + close(c) + }(i, c) + individualResultsChannels = append(individualResultsChannels, c) + } + Expect(len(individualResultsChannels)).To(Equal(3)) + finalResultChannel := make(chan int) + wg := SliceOfChannelsReducer[int, int](individualResultsChannels, finalResultChannel, func(input int, val int) int { + return val + input + }, initialValue, true) + + Expect(wg).ToNot(BeNil()) + + result := <-finalResultChannel + + Expect(result).ToNot(Equal(0)) + Expect(result).To(Equal(18)) + }) + + It("SliceOfChannelsRawMergerWithoutMapping works", func() { + individualResultsChannels := []<-chan int{} + for i := 0; i < 3; i++ { + c := make(chan int) + go func(i int, c chan int) { + for ii := 1; ii < 4; ii++ { + c <- (i * ii) + } + close(c) + }(i, c) + individualResultsChannels = append(individualResultsChannels, c) + } + Expect(len(individualResultsChannels)).To(Equal(3)) + outputChannel := make(chan int) + wg := SliceOfChannelsRawMergerWithoutMapping(individualResultsChannels, outputChannel, true) + Expect(wg).ToNot(BeNil()) + outputSlice := []int{} + for v := range outputChannel { + outputSlice = append(outputSlice, v) + } + Expect(len(outputSlice)).To(Equal(9)) + slices.Sort(outputSlice) + Expect(outputSlice[0]).To(BeZero()) + Expect(outputSlice[3]).To(Equal(1)) + Expect(outputSlice[8]).To(Equal(6)) + }) + + It("SliceOfChannelsTransformer works", func() { + individualResultsChannels := []<-chan int{} + for i := 0; i < 3; i++ { + c := make(chan int) + go func(i int, c chan int) { + for ii := 1; ii < 4; ii++ { + c <- (i * ii) + } + close(c) + }(i, c) + individualResultsChannels = append(individualResultsChannels, c) + } + Expect(len(individualResultsChannels)).To(Equal(3)) + mappingFn := func(i int) string { + return fmt.Sprintf("$%d", i) + } + + outputChannels := SliceOfChannelsTransformer(individualResultsChannels, mappingFn) + Expect(len(outputChannels)).To(Equal(3)) + rSlice := []string{} + for ii := 1; ii < 4; ii++ { + for i := 0; i < 3; i++ { + res := <-outputChannels[i] + rSlice = append(rSlice, res) + } + } + slices.Sort(rSlice) + Expect(rSlice[0]).To(Equal("$0")) + Expect(rSlice[3]).To(Equal("$1")) + Expect(rSlice[8]).To(Equal("$6")) + }) +}) diff --git a/pkg/concurrency/types.go b/pkg/concurrency/types.go new file mode 100644 index 00000000..76081ba3 --- /dev/null +++ b/pkg/concurrency/types.go @@ -0,0 +1,6 @@ +package concurrency + +type ErrorOr[T any] struct { + Value T + Error error +} diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 8fb8c39d..49a6b1bd 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,7 +41,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 0af5d94f..c0b4bc34 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { - return schema.Result{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { + return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 882db12a..0e0e56c7 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.Result{} + tresult := &schema.TranscriptionResult{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 73b185a3..b4ba4884 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. return e.s.TTS(ctx, in) } -func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { r, err := e.s.AudioTranscription(ctx, in) if err != nil { return nil, err } - tr := &schema.Result{} + tr := &schema.TranscriptionResult{} for _, s := range r.Segments { var tks []int for _, t := range s.Tokens { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 4d06544d..aa7a3fbc 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) + AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 5d9808a4..617d8f62 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -81,7 +81,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if _, err := os.Stat(uri); err == nil { serverAddress, err := getFreeAddress() if err != nil { - return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) + return "", fmt.Errorf("%s failed allocating free ports: %s", backend, err.Error()) } // Make sure the process is executable if err := ml.startProcess(uri, o.model, serverAddress); err != nil { @@ -134,7 +134,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if !ready { log.Debug().Msgf("GRPC Service NOT ready") - return "", fmt.Errorf("grpc service not ready") + return "", fmt.Errorf("%s grpc service not ready", backend) } options := *o.gRPCOptions @@ -145,10 +145,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options) if err != nil { - return "", fmt.Errorf("could not load model: %w", err) + return "", fmt.Errorf("\"%s\" could not load model: %w", backend, err) } if !res.Success { - return "", fmt.Errorf("could not load model (no success): %s", res.Message) + return "", fmt.Errorf("\"%s\" could not load model (no success): %s", backend, res.Message) } return client, nil diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go deleted file mode 100644 index b09516a7..00000000 --- a/pkg/startup/model_preload.go +++ /dev/null @@ -1,85 +0,0 @@ -package startup - -import ( - "errors" - "os" - "path/filepath" - - "github.com/go-skynet/LocalAI/embedded" - "github.com/go-skynet/LocalAI/pkg/downloader" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" -) - -// PreloadModelsConfigurations will preload models from the given list of URLs -// It will download the model if it is not already present in the model path -// It will also try to resolve if the model is an embedded model YAML configuration -func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { - for _, url := range models { - - // As a best effort, try to resolve the model from the remote library - // if it's not resolved we try with the other method below - if modelLibraryURL != "" { - lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) - if err == nil { - if lib[url] != "" { - log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) - url = lib[url] - } - } - } - - url = embedded.ModelShortURL(url) - switch { - case embedded.ExistsInModelsLibrary(url): - modelYAML, err := embedded.ResolveContent(url) - // If we resolve something, just save it to disk and continue - if err != nil { - log.Error().Err(err).Msg("error resolving model content") - continue - } - - log.Debug().Msgf("[startup] resolved embedded model: %s", url) - md5Name := utils.MD5(url) - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") - } - case downloader.LooksLikeURL(url): - log.Debug().Msgf("[startup] resolved model to download: %s", url) - - // md5 of model name - md5Name := utils.MD5(url) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) - if err != nil { - log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") - } - } - default: - if _, err := os.Stat(url); err == nil { - log.Debug().Msgf("[startup] resolved local model: %s", url) - // copy to modelPath - md5Name := utils.MD5(url) - - modelYAML, err := os.ReadFile(url) - if err != nil { - log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") - continue - } - - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") - } - } else { - log.Warn().Msgf("[startup] failed resolving model '%s'", url) - } - } - } -} diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go new file mode 100644 index 00000000..769d8a88 --- /dev/null +++ b/pkg/utils/base64.go @@ -0,0 +1,50 @@ +package utils + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +var base64DownloadClient http.Client = http.Client{ + Timeout: 30 * time.Second, +} + +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string + +// This may look weird down in pkg/utils while it is currently only used in core/config +// +// but I believe it may be useful for MQTT as well in the near future, so I'm +// extracting it while I'm thinking of it. +func GetImageURLAsBase64(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := base64DownloadClient.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +}