diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 2b0b10a8..ccbf0946 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -81,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } responses <- initialMessage - result, err := handleQuestion(config, req, ml, startupOptions, results, prompt) + result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt) if err != nil { log.Error().Err(err).Msg("error handling question") return @@ -470,7 +470,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup switch { case noActionsToRun: - result, err := handleQuestion(config, input, ml, startupOptions, results, predInput) + result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput) if err != nil { log.Error().Err(err).Msg("error handling question") return @@ -550,7 +550,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } } -func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) { +func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) { + + if len(funcResults) == 0 && result != "" { + log.Debug().Msgf("nothing function results but we had a message from the LLM") + + return result, nil + } + log.Debug().Msgf("nothing to do, computing a reply") arg := "" if len(funcResults) > 0 { diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 304f336f..c6941ff6 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -2,6 +2,7 @@ package functions import ( "encoding/json" + "fmt" "regexp" "github.com/go-skynet/LocalAI/pkg/utils" @@ -16,6 +17,8 @@ type FunctionsConfig struct { NoGrammar bool `yaml:"no_grammar"` ResponseRegex string `yaml:"response_regex"` + JSONRegexMatch string `yaml:"json_regex_match"` + // FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } } // instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }. // This might be useful for certain models trained with the function name as the first token. @@ -38,6 +41,36 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC results := []FuncCallResults{} + returnResult := func(s string) (name, arguments string, e error) { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s = utils.EscapeNewLines(s) + err := json.Unmarshal([]byte(s), &ss) + if err != nil { + log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result") + } + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name, ok := ss[functionNameKey] + if !ok { + return "", "", fmt.Errorf("unable to find function name in result") + } + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + if !ok { + return "", "", fmt.Errorf("unable to find arguments in result") + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + return "", "", fmt.Errorf("unable to cast function name to string") + } + + return funcName, string(d), nil + } + // if no grammar is used, we have to extract function and arguments from the result if !useGrammars { // the response is a string that we have to parse @@ -61,13 +94,32 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC if functionName == "" { return results } - } else { - // We expect the result to be a JSON object with a function name and arguments - err := json.Unmarshal([]byte(llmresult), &result) - if err != nil { - log.Error().Err(err).Str("llmresult", llmresult).Msg("unable to unmarshal llm result") + } else if functionConfig.JSONRegexMatch != "" { + //re := regexp.MustCompile(`(?s)(.*?)`) + //m:= re.FindStringSubmatch(`{ foo barr }`) + + // We use a regex to extract the JSON object from the response + var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch) + match := respRegex.FindStringSubmatch(llmresult) + if len(match) < 2 { return results } + + funcName, args, err := returnResult(match[1]) + if err != nil { + return results + } + + return append(results, FuncCallResults{Name: funcName, Arguments: args}) + + } else { + + funcName, args, err := returnResult(llmresult) + if err != nil { + return results + } + + return append(results, FuncCallResults{Name: funcName, Arguments: args}) } return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]}) @@ -101,32 +153,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) } } else { - // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s := utils.EscapeNewLines(llmresult) - err := json.Unmarshal([]byte(s), &ss) + funcName, args, err := returnResult(llmresult) if err != nil { - log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result") + return results } - log.Debug().Msgf("Function return: %s %+v", s, ss) - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss[functionNameKey] - if !ok { - return results - } - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - if !ok { - return results - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - return results - } - results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) + results = append(results, FuncCallResults{Name: funcName, Arguments: args}) } return results diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go index 4c2208ed..7aedc097 100644 --- a/pkg/functions/parse_test.go +++ b/pkg/functions/parse_test.go @@ -87,7 +87,7 @@ var _ = Describe("LocalAI function parse tests", func() { It("should parse the function name and arguments correctly with the name key", func() { input := `{"name": "add", "arguments": {"x": 5, "y": 3}}` functionConfig.ParallelCalls = false - functionConfig.NoGrammar = false + functionConfig.NoGrammar = true functionConfig.ResponseRegex = "" functionConfig.FunctionName = true @@ -100,7 +100,40 @@ var _ = Describe("LocalAI function parse tests", func() { It("should parse the function name and arguments correctly with the function key", func() { input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` functionConfig.ParallelCalls = false - functionConfig.NoGrammar = false + functionConfig.NoGrammar = true + functionConfig.ResponseRegex = "" + functionConfig.FunctionName = false + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + }) + + It("Should parse the result by matching the JSONRegexMatch", func() { + input := ` + +{"function": "add", "arguments": {"x": 5, "y": 3}} +` + functionConfig.ParallelCalls = false + functionConfig.NoGrammar = true + functionConfig.JSONRegexMatch = `(?s)(.*?)` + functionConfig.ResponseRegex = "" + functionConfig.FunctionName = false + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + }) + + It("Should parse the result by matching the JSONRegexMatch", func() { + input := ` +{"function": "add", "arguments": {"x": 5, "y": 3}} +` + functionConfig.ParallelCalls = false + functionConfig.NoGrammar = true + functionConfig.JSONRegexMatch = `(?s)(.*?)` functionConfig.ResponseRegex = "" functionConfig.FunctionName = false