feat(functions): allow response_regex to be a list (#2447)

feat(functions): allow regex match to be a list

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-05-31 22:52:02 +02:00 committed by GitHub
parent ff8a6962cd
commit 5d31e5269d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 16 deletions

View File

@ -93,8 +93,9 @@ parameters:
function: function:
# set to true to not use grammars # set to true to not use grammars
no_grammar: true no_grammar: true
# set a regex to extract the function tool arguments from the LLM response # set one or more regexes used to extract the function tool arguments from the LLM response
response_regex: "(?P<function>\w+)\s*\((?P<arguments>.*)\)" response_regex:
- "(?P<function>\w+)\s*\((?P<arguments>.*)\)"
``` ```
The response regex have to be a regex with named parameters to allow to scan the function name and the arguments. For instance, consider: The response regex have to be a regex with named parameters to allow to scan the function name and the arguments. For instance, consider:

View File

@ -52,7 +52,7 @@ type FunctionsConfig struct {
NoActionDescriptionName string `yaml:"no_action_description_name"` NoActionDescriptionName string `yaml:"no_action_description_name"`
// ResponseRegex is a named regex to extract the function name and arguments from the response // ResponseRegex is a named regex to extract the function name and arguments from the response
ResponseRegex string `yaml:"response_regex"` ResponseRegex []string `yaml:"response_regex"`
// JSONRegexMatch is a regex to extract the JSON object from the response // JSONRegexMatch is a regex to extract the JSON object from the response
JSONRegexMatch []string `yaml:"json_regex_match"` JSONRegexMatch []string `yaml:"json_regex_match"`
@ -228,11 +228,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
} }
} }
if functionConfig.ResponseRegex != "" { if len(functionConfig.ResponseRegex) > 0 {
// We use named regexes here to extract the function name and arguments // We use named regexes here to extract the function name and arguments
// obviously, this expects the LLM to be stable and return correctly formatted JSON // obviously, this expects the LLM to be stable and return correctly formatted JSON
// TODO: optimize this and pre-compile it // TODO: optimize this and pre-compile it
var respRegex = regexp.MustCompile(functionConfig.ResponseRegex) for _, r := range functionConfig.ResponseRegex {
var respRegex = regexp.MustCompile(r)
matches := respRegex.FindAllStringSubmatch(llmresult, -1) matches := respRegex.FindAllStringSubmatch(llmresult, -1)
for _, match := range matches { for _, match := range matches {
for i, name := range respRegex.SubexpNames() { for i, name := range respRegex.SubexpNames() {
@ -247,6 +248,7 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
} }
results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]}) results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
} }
}
} else { } else {
if len(llmResults) == 0 { if len(llmResults) == 0 {
llmResults = append(llmResults, llmresult) llmResults = append(llmResults, llmresult)

View File

@ -28,7 +28,7 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when not using grammars and regex is needed", func() { Context("when not using grammars and regex is needed", func() {
It("should extract function name and arguments from the regex", func() { It("should extract function name and arguments from the regex", func() {
input := `add({"x":5,"y":3})` input := `add({"x":5,"y":3})`
functionConfig.ResponseRegex = `(?P<function>\w+)\s*\((?P<arguments>.*)\)` functionConfig.ResponseRegex = []string{`(?P<function>\w+)\s*\((?P<arguments>.*)\)`}
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))