diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 2b0b10a8..ccbf0946 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -81,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
responses <- initialMessage
- result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
+ result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -470,7 +470,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
switch {
case noActionsToRun:
- result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
+ result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -550,7 +550,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
}
-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+
+ if len(funcResults) == 0 && result != "" {
+ log.Debug().Msgf("nothing function results but we had a message from the LLM")
+
+ return result, nil
+ }
+
log.Debug().Msgf("nothing to do, computing a reply")
arg := ""
if len(funcResults) > 0 {
diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go
index 304f336f..c6941ff6 100644
--- a/pkg/functions/parse.go
+++ b/pkg/functions/parse.go
@@ -2,6 +2,7 @@ package functions
import (
"encoding/json"
+ "fmt"
"regexp"
"github.com/go-skynet/LocalAI/pkg/utils"
@@ -16,6 +17,8 @@ type FunctionsConfig struct {
NoGrammar bool `yaml:"no_grammar"`
ResponseRegex string `yaml:"response_regex"`
+ JSONRegexMatch string `yaml:"json_regex_match"`
+
// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
// This might be useful for certain models trained with the function name as the first token.
@@ -38,6 +41,36 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
results := []FuncCallResults{}
+ returnResult := func(s string) (name, arguments string, e error) {
+ // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
+ ss := map[string]interface{}{}
+ // This prevent newlines to break JSON parsing for clients
+ s = utils.EscapeNewLines(s)
+ err := json.Unmarshal([]byte(s), &ss)
+ if err != nil {
+ log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
+ }
+ log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+ // The grammar defines the function name as "function", while OpenAI returns "name"
+ func_name, ok := ss[functionNameKey]
+ if !ok {
+ return "", "", fmt.Errorf("unable to find function name in result")
+ }
+ // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
+ args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+ if !ok {
+ return "", "", fmt.Errorf("unable to find arguments in result")
+ }
+ d, _ := json.Marshal(args)
+ funcName, ok := func_name.(string)
+ if !ok {
+ return "", "", fmt.Errorf("unable to cast function name to string")
+ }
+
+ return funcName, string(d), nil
+ }
+
// if no grammar is used, we have to extract function and arguments from the result
if !useGrammars {
// the response is a string that we have to parse
@@ -61,13 +94,32 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
if functionName == "" {
return results
}
- } else {
- // We expect the result to be a JSON object with a function name and arguments
- err := json.Unmarshal([]byte(llmresult), &result)
- if err != nil {
- log.Error().Err(err).Str("llmresult", llmresult).Msg("unable to unmarshal llm result")
+ } else if functionConfig.JSONRegexMatch != "" {
+ //re := regexp.MustCompile(`(?s)(.*?)`)
+ //m:= re.FindStringSubmatch(`{ foo barr }`)
+
+ // We use a regex to extract the JSON object from the response
+ var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
+ match := respRegex.FindStringSubmatch(llmresult)
+ if len(match) < 2 {
return results
}
+
+ funcName, args, err := returnResult(match[1])
+ if err != nil {
+ return results
+ }
+
+ return append(results, FuncCallResults{Name: funcName, Arguments: args})
+
+ } else {
+
+ funcName, args, err := returnResult(llmresult)
+ if err != nil {
+ return results
+ }
+
+ return append(results, FuncCallResults{Name: funcName, Arguments: args})
}
return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
@@ -101,32 +153,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
}
} else {
- // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
- ss := map[string]interface{}{}
- // This prevent newlines to break JSON parsing for clients
- s := utils.EscapeNewLines(llmresult)
- err := json.Unmarshal([]byte(s), &ss)
+ funcName, args, err := returnResult(llmresult)
if err != nil {
- log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
+ return results
}
- log.Debug().Msgf("Function return: %s %+v", s, ss)
- // The grammar defines the function name as "function", while OpenAI returns "name"
- func_name, ok := ss[functionNameKey]
- if !ok {
- return results
- }
- // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
- args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
- if !ok {
- return results
- }
- d, _ := json.Marshal(args)
- funcName, ok := func_name.(string)
- if !ok {
- return results
- }
- results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
+ results = append(results, FuncCallResults{Name: funcName, Arguments: args})
}
return results
diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go
index 4c2208ed..7aedc097 100644
--- a/pkg/functions/parse_test.go
+++ b/pkg/functions/parse_test.go
@@ -87,7 +87,7 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the name key", func() {
input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = false
+ functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = true
@@ -100,7 +100,40 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the function key", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = false
+ functionConfig.NoGrammar = true
+ functionConfig.ResponseRegex = ""
+ functionConfig.FunctionName = false
+
+ results := ParseFunctionCall(input, functionConfig)
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].Name).To(Equal("add"))
+ Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+ })
+
+ It("Should parse the result by matching the JSONRegexMatch", func() {
+ input := `
+
+{"function": "add", "arguments": {"x": 5, "y": 3}}
+`
+ functionConfig.ParallelCalls = false
+ functionConfig.NoGrammar = true
+ functionConfig.JSONRegexMatch = `(?s)(.*?)`
+ functionConfig.ResponseRegex = ""
+ functionConfig.FunctionName = false
+
+ results := ParseFunctionCall(input, functionConfig)
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].Name).To(Equal("add"))
+ Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+ })
+
+ It("Should parse the result by matching the JSONRegexMatch", func() {
+ input := `
+{"function": "add", "arguments": {"x": 5, "y": 3}}
+`
+ functionConfig.ParallelCalls = false
+ functionConfig.NoGrammar = true
+ functionConfig.JSONRegexMatch = `(?s)(.*?)`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false