diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go
index 0246d70e..5327ee6d 100644
--- a/pkg/functions/parse.go
+++ b/pkg/functions/parse.go
@@ -2,7 +2,6 @@ package functions
import (
"encoding/json"
- "fmt"
"regexp"
"strings"
@@ -68,9 +67,6 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
log.Debug().Msgf("LLM result(processed): %s", llmresult)
- multipleResults := functionConfig.ParallelCalls
- useGrammars := !functionConfig.NoGrammar
-
functionNameKey := "function"
if functionConfig.FunctionName {
functionNameKey = "name"
@@ -78,124 +74,85 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
results := []FuncCallResults{}
- returnResult := func(s string) (name, arguments string, e error) {
+ returnResult := func(s string) (result []FuncCallResults, e error) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
- ss := map[string]interface{}{}
- // This prevent newlines to break JSON parsing for clients
+ var ss []map[string]interface{}
+ result = make([]FuncCallResults, 0)
s = utils.EscapeNewLines(s)
err := json.Unmarshal([]byte(s), &ss)
if err != nil {
- log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
- }
- log.Debug().Msgf("Function return: %s %+v", s, ss)
-
- // The grammar defines the function name as "function", while OpenAI returns "name"
- func_name, ok := ss[functionNameKey]
- if !ok {
- return "", "", fmt.Errorf("unable to find function name in result")
- }
- // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
- args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
- if !ok {
- return "", "", fmt.Errorf("unable to find arguments in result")
- }
- d, _ := json.Marshal(args)
- funcName, ok := func_name.(string)
- if !ok {
- return "", "", fmt.Errorf("unable to cast function name to string")
- }
-
- return funcName, string(d), nil
- }
-
- // if no grammar is used, we have to extract function and arguments from the result
- if !useGrammars {
- // the response is a string that we have to parse
- result := make(map[string]string)
-
- if functionConfig.ResponseRegex != "" {
- // We use named regexes here to extract the function name and arguments
- // obviously, this expects the LLM to be stable and return correctly formatted JSON
- // TODO: optimize this and pre-compile it
- var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
- match := respRegex.FindStringSubmatch(llmresult)
- for i, name := range respRegex.SubexpNames() {
- if i != 0 && name != "" && len(match) > i {
- result[name] = match[i]
- }
- }
-
- // TODO: open point about multiple results and/or mixed with chat messages
- // This is not handled as for now, we only expect one function call per response
- functionName := result[functionNameKey]
- if functionName == "" {
- return results
- }
- } else if functionConfig.JSONRegexMatch != "" {
- //re := regexp.MustCompile(`(?s)(.*?)`)
- //m:= re.FindStringSubmatch(`{ foo barr }`)
-
- // We use a regex to extract the JSON object from the response
- var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
- match := respRegex.FindStringSubmatch(llmresult)
- if len(match) < 2 {
- return results
- }
-
- funcName, args, err := returnResult(match[1])
+ // If the LLM result is a single object, try unmarshaling it into a single map
+ var singleObj map[string]interface{}
+ err = json.Unmarshal([]byte(s), &singleObj)
if err != nil {
- return results
+ log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
+ } else {
+ ss = []map[string]interface{}{singleObj}
}
-
- return append(results, FuncCallResults{Name: funcName, Arguments: args})
-
- } else {
-
- funcName, args, err := returnResult(llmresult)
- if err != nil {
- return results
- }
-
- return append(results, FuncCallResults{Name: funcName, Arguments: args})
}
- return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
- }
-
- // with grammars
- // TODO: use generics to avoid this code duplication
- if multipleResults {
- ss := []map[string]interface{}{}
- s := utils.EscapeNewLines(llmresult)
- err := json.Unmarshal([]byte(s), &ss)
- if err != nil {
- log.Warn().Err(err).Str("escapedLLMResult", s).Msg("multiple results: unable to unmarshal llm result")
- }
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
+ // The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := s[functionNameKey]
if !ok {
continue
+ //return result, fmt.Errorf("unable to find function name in result")
}
- args, ok := s["arguments"]
+ // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
+ args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
continue
+ //return result, fmt.Errorf("unable to find arguments in result")
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
+ //return result, fmt.Errorf("unable to cast function name to string")
}
- results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
+
+ result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
}
- } else {
- funcName, args, err := returnResult(llmresult)
- if err != nil {
+
+ return result, nil
+ }
+
+ // the response is a string that we have to parse
+ result := make(map[string]string)
+
+ if functionConfig.ResponseRegex != "" {
+ // We use named regexes here to extract the function name and arguments
+ // obviously, this expects the LLM to be stable and return correctly formatted JSON
+ // TODO: optimize this and pre-compile it
+ var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
+ match := respRegex.FindStringSubmatch(llmresult)
+ for i, name := range respRegex.SubexpNames() {
+ if i != 0 && name != "" && len(match) > i {
+ result[name] = match[i]
+ }
+ }
+
+ // TODO: open point about multiple results and/or mixed with chat messages
+ // This is not handled as for now, we only expect one function call per response
+ functionName := result[functionNameKey]
+ if functionName == "" {
+ return results
+ }
+ results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
+ } else if functionConfig.JSONRegexMatch != "" {
+
+ // We use a regex to extract the JSON object from the response
+ var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
+ match := respRegex.FindStringSubmatch(llmresult)
+ if len(match) < 2 {
return results
}
- results = append(results, FuncCallResults{Name: funcName, Arguments: args})
+ results, _ = returnResult(match[1])
+ } else {
+ results, _ = returnResult(llmresult)
}
return results
diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go
index 7aedc097..03a01239 100644
--- a/pkg/functions/parse_test.go
+++ b/pkg/functions/parse_test.go
@@ -11,18 +11,12 @@ var _ = Describe("LocalAI function parse tests", func() {
BeforeEach(func() {
// Default configuration setup
- functionConfig = FunctionsConfig{
- ParallelCalls: false,
- NoGrammar: false,
- ResponseRegex: `(?P\w+)\s*\((?P.*)\)`,
- }
+ functionConfig = FunctionsConfig{}
})
Context("when using grammars and single result expected", func() {
It("should parse the function name and arguments correctly", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
- functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = false
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
@@ -34,7 +28,7 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when not using grammars and regex is needed", func() {
It("should extract function name and arguments from the regex", func() {
input := `add({"x":5,"y":3})`
- functionConfig.NoGrammar = true
+ functionConfig.ResponseRegex = `(?P\w+)\s*\((?P.*)\)`
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
@@ -46,33 +40,19 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when having invalid input", func() {
It("returns no results when there is no input", func() {
input := ""
- functionConfig.NoGrammar = true
-
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
-
- functionConfig.NoGrammar = false
-
- results = ParseFunctionCall(input, functionConfig)
- Expect(results).To(HaveLen(0))
})
It("returns no results when is invalid", func() {
input := "invalid input"
- functionConfig.NoGrammar = true
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
- functionConfig.NoGrammar = false
-
- results = ParseFunctionCall(input, functionConfig)
- Expect(results).To(HaveLen(0))
})
})
Context("when parallel calls are enabled", func() {
It("should handle multiple function calls", func() {
input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`
- functionConfig.ParallelCalls = true
- functionConfig.NoGrammar = false
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(2))
@@ -86,9 +66,6 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("without grammars and without regex", func() {
It("should parse the function name and arguments correctly with the name key", func() {
input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
- functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = true
- functionConfig.ResponseRegex = ""
functionConfig.FunctionName = true
results := ParseFunctionCall(input, functionConfig)
@@ -99,10 +76,6 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the function key", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
- functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = true
- functionConfig.ResponseRegex = ""
- functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
@@ -115,11 +88,8 @@ var _ = Describe("LocalAI function parse tests", func() {
{"function": "add", "arguments": {"x": 5, "y": 3}}
`
- functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = true
+
functionConfig.JSONRegexMatch = `(?s)(.*?)`
- functionConfig.ResponseRegex = ""
- functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
@@ -131,11 +101,8 @@ var _ = Describe("LocalAI function parse tests", func() {
input := `
{"function": "add", "arguments": {"x": 5, "y": 3}}
`
- functionConfig.ParallelCalls = false
- functionConfig.NoGrammar = true
+
functionConfig.JSONRegexMatch = `(?s)(.*?)`
- functionConfig.ResponseRegex = ""
- functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))