diff --git a/core/backend/backend_suite_test.go b/core/backend/backend_suite_test.go new file mode 100644 index 00000000..541c91f6 --- /dev/null +++ b/core/backend/backend_suite_test.go @@ -0,0 +1,13 @@ +package backend_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestBackend(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Backend test suite") +} diff --git a/core/backend/llm.go b/core/backend/llm.go index 72c4ad9f..2b4564a8 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -9,6 +9,8 @@ import ( "sync" "unicode/utf8" + "github.com/rs/zerolog/log" + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" @@ -181,13 +183,37 @@ func Finetune(config config.BackendConfig, input, prediction string) string { mu.Lock() reg, ok := cutstrings[c] if !ok { - cutstrings[c] = regexp.MustCompile(c) + r, err := regexp.Compile(c) + if err != nil { + log.Fatal().Err(err).Msg("failed to compile regex") + } + cutstrings[c] = r reg = cutstrings[c] } mu.Unlock() prediction = reg.ReplaceAllString(prediction, "") } + // extract results from the response which can be for instance inside XML tags + var predResult string + for _, r := range config.ExtractRegex { + mu.Lock() + reg, ok := cutstrings[r] + if !ok { + regex, err := regexp.Compile(r) + if err != nil { + log.Fatal().Err(err).Msg("failed to compile regex") + } + cutstrings[r] = regex + reg = regex + } + mu.Unlock() + predResult += reg.FindString(prediction) + } + if predResult != "" { + prediction = predResult + } + for _, c := range config.TrimSpace { prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) } diff --git a/core/backend/llm_test.go b/core/backend/llm_test.go new file mode 100644 index 00000000..f7630702 --- /dev/null +++ b/core/backend/llm_test.go @@ -0,0 +1,109 @@ +package backend_test + +import ( + . "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LLM tests", func() { + Context("Finetune LLM output", func() { + var ( + testConfig config.BackendConfig + input string + prediction string + result string + ) + + BeforeEach(func() { + testConfig = config.BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Echo: false, + }, + LLMConfig: config.LLMConfig{ + Cutstrings: []string{`<.*?>`}, // Example regex for removing XML tags + ExtractRegex: []string{`(.*?)`}, // Example regex to extract from tags + TrimSpace: []string{" ", "\n"}, + TrimSuffix: []string{".", "!"}, + }, + } + }) + + Context("when echo is enabled", func() { + BeforeEach(func() { + testConfig.Echo = true + input = "Hello" + prediction = "World" + }) + + It("should prepend input to prediction", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("HelloWorld")) + }) + }) + + Context("when echo is disabled", func() { + BeforeEach(func() { + testConfig.Echo = false + input = "Hello" + prediction = "World" + }) + + It("should not modify the prediction with input", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("World")) + }) + }) + + Context("when cutstrings regex is applied", func() { + BeforeEach(func() { + input = "" + prediction = "
Hello
World" + }) + + It("should remove substrings matching cutstrings regex", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("Hello World")) + }) + }) + + Context("when extract regex is applied", func() { + BeforeEach(func() { + input = "" + prediction = "42" + }) + + It("should extract substrings matching the extract regex", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("42")) + }) + }) + + Context("when trimming spaces", func() { + BeforeEach(func() { + input = "" + prediction = " Hello World " + }) + + It("should trim spaces from the prediction", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("Hello World")) + }) + }) + + Context("when trimming suffixes", func() { + BeforeEach(func() { + input = "" + prediction = "Hello World." + }) + + It("should trim suffixes from the prediction", func() { + result = Finetune(testConfig, input, prediction) + Expect(result).To(Equal("Hello World")) + }) + }) + }) +}) diff --git a/core/config/backend_config.go b/core/config/backend_config.go index b83e1a98..027e18a4 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -126,6 +126,7 @@ type LLMConfig struct { Grammar string `yaml:"grammar"` StopWords []string `yaml:"stopwords"` Cutstrings []string `yaml:"cutstrings"` + ExtractRegex []string `yaml:"extract_regex"` TrimSpace []string `yaml:"trimspace"` TrimSuffix []string `yaml:"trimsuffix"` diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index a979b7bc..8144bdcd 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -68,9 +68,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig) result = functions.CleanupLLMResult(result, config.FunctionsConfig) - results := functions.ParseFunctionCall(result, config.FunctionsConfig) + functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig) log.Debug().Msgf("Text content to return: %s", textContentToReturn) - noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0 + noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0 switch { case noActionToRun: @@ -83,7 +83,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } responses <- initialMessage - result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt) + result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt) if err != nil { log.Error().Err(err).Msg("error handling question") return @@ -105,7 +105,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup responses <- resp default: - for i, ss := range results { + for i, ss := range functionResults { name, args := ss.Name, ss.Arguments initialMessage := schema.OpenAIResponse{