diff --git a/core/backend/backend_suite_test.go b/core/backend/backend_suite_test.go
new file mode 100644
index 00000000..541c91f6
--- /dev/null
+++ b/core/backend/backend_suite_test.go
@@ -0,0 +1,13 @@
+package backend_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestBackend(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Backend test suite")
+}
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 72c4ad9f..2b4564a8 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -9,6 +9,8 @@ import (
"sync"
"unicode/utf8"
+ "github.com/rs/zerolog/log"
+
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
@@ -181,13 +183,37 @@ func Finetune(config config.BackendConfig, input, prediction string) string {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
- cutstrings[c] = regexp.MustCompile(c)
+ r, err := regexp.Compile(c)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to compile regex")
+ }
+ cutstrings[c] = r
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
+ // extract results from the response which can be for instance inside XML tags
+ var predResult string
+ for _, r := range config.ExtractRegex {
+ mu.Lock()
+ reg, ok := cutstrings[r]
+ if !ok {
+ regex, err := regexp.Compile(r)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to compile regex")
+ }
+ cutstrings[r] = regex
+ reg = regex
+ }
+ mu.Unlock()
+ predResult += reg.FindString(prediction)
+ }
+ if predResult != "" {
+ prediction = predResult
+ }
+
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
diff --git a/core/backend/llm_test.go b/core/backend/llm_test.go
new file mode 100644
index 00000000..f7630702
--- /dev/null
+++ b/core/backend/llm_test.go
@@ -0,0 +1,109 @@
+package backend_test
+
+import (
+ . "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("LLM tests", func() {
+ Context("Finetune LLM output", func() {
+ var (
+ testConfig config.BackendConfig
+ input string
+ prediction string
+ result string
+ )
+
+ BeforeEach(func() {
+ testConfig = config.BackendConfig{
+ PredictionOptions: schema.PredictionOptions{
+ Echo: false,
+ },
+ LLMConfig: config.LLMConfig{
+ Cutstrings: []string{`<.*?>`}, // Example regex for removing XML tags
+ ExtractRegex: []string{`<answer>(.*?)</answer>`}, // Example regex to extract from <answer> tags
+ TrimSpace: []string{" ", "\n"},
+ TrimSuffix: []string{".", "!"},
+ },
+ }
+ })
+
+ Context("when echo is enabled", func() {
+ BeforeEach(func() {
+ testConfig.Echo = true
+ input = "Hello"
+ prediction = "World"
+ })
+
+ It("should prepend input to prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("HelloWorld"))
+ })
+ })
+
+ Context("when echo is disabled", func() {
+ BeforeEach(func() {
+ testConfig.Echo = false
+ input = "Hello"
+ prediction = "World"
+ })
+
+ It("should not modify the prediction with input", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("World"))
+ })
+ })
+
+ Context("when cutstrings regex is applied", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "<div>Hello</div> <div>World</div>"
+ })
+
+ It("should remove substrings matching cutstrings regex", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+
+ Context("when extract regex is applied", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "<answer>42</answer>"
+ })
+
+ It("should extract substrings matching the extract regex", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("42"))
+ })
+ })
+
+ Context("when trimming spaces", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = " Hello World "
+ })
+
+ It("should trim spaces from the prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+
+ Context("when trimming suffixes", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "Hello World."
+ })
+
+ It("should trim suffixes from the prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+ })
+})
diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index b83e1a98..027e18a4 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -126,6 +126,7 @@ type LLMConfig struct {
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
+ ExtractRegex []string `yaml:"extract_regex"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index a979b7bc..8144bdcd 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -68,9 +68,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
- results := functions.ParseFunctionCall(result, config.FunctionsConfig)
+ functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
log.Debug().Msgf("Text content to return: %s", textContentToReturn)
- noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
+ noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
switch {
case noActionToRun:
@@ -83,7 +83,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
responses <- initialMessage
- result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt)
+ result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -105,7 +105,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
responses <- resp
default:
- for i, ss := range results {
+ for i, ss := range functionResults {
name, args := ss.Name, ss.Arguments
initialMessage := schema.OpenAIResponse{