diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 0246d70e..5327ee6d 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -2,7 +2,6 @@ package functions import ( "encoding/json" - "fmt" "regexp" "strings" @@ -68,9 +67,6 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC log.Debug().Msgf("LLM result(processed): %s", llmresult) - multipleResults := functionConfig.ParallelCalls - useGrammars := !functionConfig.NoGrammar - functionNameKey := "function" if functionConfig.FunctionName { functionNameKey = "name" @@ -78,124 +74,85 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC results := []FuncCallResults{} - returnResult := func(s string) (name, arguments string, e error) { + returnResult := func(s string) (result []FuncCallResults, e error) { // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients + var ss []map[string]interface{} + result = make([]FuncCallResults, 0) s = utils.EscapeNewLines(s) err := json.Unmarshal([]byte(s), &ss) if err != nil { - log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result") - } - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss[functionNameKey] - if !ok { - return "", "", fmt.Errorf("unable to find function name in result") - } - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - if !ok { - return "", "", fmt.Errorf("unable to find arguments in result") - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - return "", "", fmt.Errorf("unable to cast function name to string") - } - - return funcName, string(d), nil - } - - // if no grammar is used, we have to extract function and arguments from the result - if !useGrammars { - // the response is a string that we have to parse - result := make(map[string]string) - - if functionConfig.ResponseRegex != "" { - // We use named regexes here to extract the function name and arguments - // obviously, this expects the LLM to be stable and return correctly formatted JSON - // TODO: optimize this and pre-compile it - var respRegex = regexp.MustCompile(functionConfig.ResponseRegex) - match := respRegex.FindStringSubmatch(llmresult) - for i, name := range respRegex.SubexpNames() { - if i != 0 && name != "" && len(match) > i { - result[name] = match[i] - } - } - - // TODO: open point about multiple results and/or mixed with chat messages - // This is not handled as for now, we only expect one function call per response - functionName := result[functionNameKey] - if functionName == "" { - return results - } - } else if functionConfig.JSONRegexMatch != "" { - //re := regexp.MustCompile(`(?s)(.*?)`) - //m:= re.FindStringSubmatch(`{ foo barr }`) - - // We use a regex to extract the JSON object from the response - var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch) - match := respRegex.FindStringSubmatch(llmresult) - if len(match) < 2 { - return results - } - - funcName, args, err := returnResult(match[1]) + // If the LLM result is a single object, try unmarshaling it into a single map + var singleObj map[string]interface{} + err = json.Unmarshal([]byte(s), &singleObj) if err != nil { - return results + log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result") + } else { + ss = []map[string]interface{}{singleObj} } - - return append(results, FuncCallResults{Name: funcName, Arguments: args}) - - } else { - - funcName, args, err := returnResult(llmresult) - if err != nil { - return results - } - - return append(results, FuncCallResults{Name: funcName, Arguments: args}) } - return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]}) - } - - // with grammars - // TODO: use generics to avoid this code duplication - if multipleResults { - ss := []map[string]interface{}{} - s := utils.EscapeNewLines(llmresult) - err := json.Unmarshal([]byte(s), &ss) - if err != nil { - log.Warn().Err(err).Str("escapedLLMResult", s).Msg("multiple results: unable to unmarshal llm result") - } log.Debug().Msgf("Function return: %s %+v", s, ss) for _, s := range ss { + // The grammar defines the function name as "function", while OpenAI returns "name" func_name, ok := s[functionNameKey] if !ok { continue + //return result, fmt.Errorf("unable to find function name in result") } - args, ok := s["arguments"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) if !ok { continue + //return result, fmt.Errorf("unable to find arguments in result") } d, _ := json.Marshal(args) funcName, ok := func_name.(string) if !ok { continue + //return result, fmt.Errorf("unable to cast function name to string") } - results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) + + result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)}) } - } else { - funcName, args, err := returnResult(llmresult) - if err != nil { + + return result, nil + } + + // the response is a string that we have to parse + result := make(map[string]string) + + if functionConfig.ResponseRegex != "" { + // We use named regexes here to extract the function name and arguments + // obviously, this expects the LLM to be stable and return correctly formatted JSON + // TODO: optimize this and pre-compile it + var respRegex = regexp.MustCompile(functionConfig.ResponseRegex) + match := respRegex.FindStringSubmatch(llmresult) + for i, name := range respRegex.SubexpNames() { + if i != 0 && name != "" && len(match) > i { + result[name] = match[i] + } + } + + // TODO: open point about multiple results and/or mixed with chat messages + // This is not handled as for now, we only expect one function call per response + functionName := result[functionNameKey] + if functionName == "" { + return results + } + results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]}) + } else if functionConfig.JSONRegexMatch != "" { + + // We use a regex to extract the JSON object from the response + var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch) + match := respRegex.FindStringSubmatch(llmresult) + if len(match) < 2 { return results } - results = append(results, FuncCallResults{Name: funcName, Arguments: args}) + results, _ = returnResult(match[1]) + } else { + results, _ = returnResult(llmresult) } return results diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go index 7aedc097..03a01239 100644 --- a/pkg/functions/parse_test.go +++ b/pkg/functions/parse_test.go @@ -11,18 +11,12 @@ var _ = Describe("LocalAI function parse tests", func() { BeforeEach(func() { // Default configuration setup - functionConfig = FunctionsConfig{ - ParallelCalls: false, - NoGrammar: false, - ResponseRegex: `(?P\w+)\s*\((?P.*)\)`, - } + functionConfig = FunctionsConfig{} }) Context("when using grammars and single result expected", func() { It("should parse the function name and arguments correctly", func() { input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` - functionConfig.ParallelCalls = false - functionConfig.NoGrammar = false results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(1)) @@ -34,7 +28,7 @@ var _ = Describe("LocalAI function parse tests", func() { Context("when not using grammars and regex is needed", func() { It("should extract function name and arguments from the regex", func() { input := `add({"x":5,"y":3})` - functionConfig.NoGrammar = true + functionConfig.ResponseRegex = `(?P\w+)\s*\((?P.*)\)` results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(1)) @@ -46,33 +40,19 @@ var _ = Describe("LocalAI function parse tests", func() { Context("when having invalid input", func() { It("returns no results when there is no input", func() { input := "" - functionConfig.NoGrammar = true - results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(0)) - - functionConfig.NoGrammar = false - - results = ParseFunctionCall(input, functionConfig) - Expect(results).To(HaveLen(0)) }) It("returns no results when is invalid", func() { input := "invalid input" - functionConfig.NoGrammar = true results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(0)) - functionConfig.NoGrammar = false - - results = ParseFunctionCall(input, functionConfig) - Expect(results).To(HaveLen(0)) }) }) Context("when parallel calls are enabled", func() { It("should handle multiple function calls", func() { input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]` - functionConfig.ParallelCalls = true - functionConfig.NoGrammar = false results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(2)) @@ -86,9 +66,6 @@ var _ = Describe("LocalAI function parse tests", func() { Context("without grammars and without regex", func() { It("should parse the function name and arguments correctly with the name key", func() { input := `{"name": "add", "arguments": {"x": 5, "y": 3}}` - functionConfig.ParallelCalls = false - functionConfig.NoGrammar = true - functionConfig.ResponseRegex = "" functionConfig.FunctionName = true results := ParseFunctionCall(input, functionConfig) @@ -99,10 +76,6 @@ var _ = Describe("LocalAI function parse tests", func() { It("should parse the function name and arguments correctly with the function key", func() { input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` - functionConfig.ParallelCalls = false - functionConfig.NoGrammar = true - functionConfig.ResponseRegex = "" - functionConfig.FunctionName = false results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(1)) @@ -115,11 +88,8 @@ var _ = Describe("LocalAI function parse tests", func() { {"function": "add", "arguments": {"x": 5, "y": 3}} ` - functionConfig.ParallelCalls = false - functionConfig.NoGrammar = true + functionConfig.JSONRegexMatch = `(?s)(.*?)` - functionConfig.ResponseRegex = "" - functionConfig.FunctionName = false results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(1)) @@ -131,11 +101,8 @@ var _ = Describe("LocalAI function parse tests", func() { input := ` {"function": "add", "arguments": {"x": 5, "y": 3}} ` - functionConfig.ParallelCalls = false - functionConfig.NoGrammar = true + functionConfig.JSONRegexMatch = `(?s)(.*?)` - functionConfig.ResponseRegex = "" - functionConfig.FunctionName = false results := ParseFunctionCall(input, functionConfig) Expect(results).To(HaveLen(1))