feat(functions): allow to set JSON matcher (#2319)
Signed-off-by: mudler <mudler@localai.io>

Parent: c4186f13c3
Commit: 84e2407afa
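
In short, this commit adds a json_regex_match option to FunctionsConfig: when grammars are disabled, the configured regex extracts the JSON tool call from the raw model output (for example, text wrapped in <tool_call> tags) before it is parsed. Below is a minimal sketch of exercising the new option, mirroring the added test case; the pkg/functions import path is an assumption inferred from the pkg/utils import in this diff.

```go
package main

import (
	"fmt"

	// Assumed import path for the functions package touched by this diff;
	// adjust to the actual repository layout if it differs.
	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// Raw model output wrapping the tool call in <tool_call> tags, as in the new test.
	llmresult := `
<tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`

	cfg := functions.FunctionsConfig{
		NoGrammar: true,
		// One capturing group around the JSON object is required, since the
		// parser reads the first submatch.
		JSONRegexMatch: `(?s)<tool_call>(.*?)</tool_call>`,
	}

	for _, r := range functions.ParseFunctionCall(llmresult, cfg) {
		fmt.Println(r.Name, r.Arguments) // add {"x":5,"y":3}
	}
}
```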

@@ -81,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
             }
             responses <- initialMessage

-            result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
+            result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt)
             if err != nil {
                 log.Error().Err(err).Msg("error handling question")
                 return

@@ -470,7 +470,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

     switch {
     case noActionsToRun:
-        result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
+        result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
         if err != nil {
             log.Error().Err(err).Msg("error handling question")
             return

@@ -550,7 +550,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
     }
 }

-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+
+    if len(funcResults) == 0 && result != "" {
+        log.Debug().Msgf("nothing function results but we had a message from the LLM")
+
+        return result, nil
+    }
+
     log.Debug().Msgf("nothing to do, computing a reply")
     arg := ""
     if len(funcResults) > 0 {
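
The signature change above threads the raw model text into handleQuestion, so a plain (non-tool-call) answer can be returned as-is instead of triggering another completion. A standalone sketch of that control flow, using the hypothetical names replyOrCompute and compute for the handler and the follow-up inference call:

```go
package main

import "fmt"

// FuncCallResults mirrors the parser result type used in this diff
// (function name plus stringified arguments).
type FuncCallResults struct {
	Name      string
	Arguments string
}

// replyOrCompute sketches the guard added to handleQuestion: when the parser
// returned no tool calls but the model already produced plain text, that text
// is reused as the assistant reply; otherwise a reply is computed.
func replyOrCompute(funcResults []FuncCallResults, llmText string, compute func() string) string {
	if len(funcResults) == 0 && llmText != "" {
		return llmText
	}
	return compute()
}

func main() {
	reply := replyOrCompute(nil, "Hello! How can I help you today?", func() string {
		return "(would run a second completion here)"
	})
	fmt.Println(reply) // prints the model's original message
}
```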

@@ -2,6 +2,7 @@ package functions

 import (
     "encoding/json"
+    "fmt"
     "regexp"

     "github.com/go-skynet/LocalAI/pkg/utils"

@@ -16,6 +17,8 @@ type FunctionsConfig struct {
     NoGrammar     bool   `yaml:"no_grammar"`
     ResponseRegex string `yaml:"response_regex"`

+    JSONRegexMatch string `yaml:"json_regex_match"`
+
     // FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
     // instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
     // This might be useful for certain models trained with the function name as the first token.

@@ -38,6 +41,36 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC

     results := []FuncCallResults{}

+    returnResult := func(s string) (name, arguments string, e error) {
+        // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
+        ss := map[string]interface{}{}
+        // This prevent newlines to break JSON parsing for clients
+        s = utils.EscapeNewLines(s)
+        err := json.Unmarshal([]byte(s), &ss)
+        if err != nil {
+            log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
+        }
+        log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+        // The grammar defines the function name as "function", while OpenAI returns "name"
+        func_name, ok := ss[functionNameKey]
+        if !ok {
+            return "", "", fmt.Errorf("unable to find function name in result")
+        }
+        // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
+        args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+        if !ok {
+            return "", "", fmt.Errorf("unable to find arguments in result")
+        }
+        d, _ := json.Marshal(args)
+        funcName, ok := func_name.(string)
+        if !ok {
+            return "", "", fmt.Errorf("unable to cast function name to string")
+        }
+
+        return funcName, string(d), nil
+    }
+
     // if no grammar is used, we have to extract function and arguments from the result
     if !useGrammars {
         // the response is a string that we have to parse

@@ -61,13 +94,32 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
         if functionName == "" {
             return results
         }
-    } else {
-        // We expect the result to be a JSON object with a function name and arguments
-        err := json.Unmarshal([]byte(llmresult), &result)
-        if err != nil {
-            log.Error().Err(err).Str("llmresult", llmresult).Msg("unable to unmarshal llm result")
+    } else if functionConfig.JSONRegexMatch != "" {
+        //re := regexp.MustCompile(`(?s)<tool_call>(.*?)</tool_call>`)
+        //m:= re.FindStringSubmatch(`<tool_call>{ foo barr }</tool_call>`)
+
+        // We use a regex to extract the JSON object from the response
+        var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
+        match := respRegex.FindStringSubmatch(llmresult)
+        if len(match) < 2 {
             return results
         }
+
+        funcName, args, err := returnResult(match[1])
+        if err != nil {
+            return results
+        }
+
+        return append(results, FuncCallResults{Name: funcName, Arguments: args})
+
+    } else {
+
+        funcName, args, err := returnResult(llmresult)
+        if err != nil {
+            return results
+        }
+
+        return append(results, FuncCallResults{Name: funcName, Arguments: args})
     }

     return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
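
Note that the new branch reads match[1] and bails out when len(match) < 2, so the configured json_regex_match pattern must contain exactly one capturing group around the JSON payload. A small self-contained illustration of that requirement, using only the standard library and the tool-call format from the tests:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// The parser calls FindStringSubmatch and uses the first submatch, so the
	// pattern needs one capturing group around the JSON object.
	re := regexp.MustCompile(`(?s)<tool_call>(.*?)</tool_call>`)

	llmresult := `<tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`

	match := re.FindStringSubmatch(llmresult)
	if len(match) < 2 {
		fmt.Println("no JSON payload found; the parser would return an empty result set")
		return
	}
	fmt.Println(match[1]) // the raw JSON object handed to returnResult
}
```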

@@ -101,32 +153,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
             results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
         }
     } else {
-        // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
-        ss := map[string]interface{}{}
-        // This prevent newlines to break JSON parsing for clients
-        s := utils.EscapeNewLines(llmresult)
-        err := json.Unmarshal([]byte(s), &ss)
+        funcName, args, err := returnResult(llmresult)
         if err != nil {
-            log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
             return results
         }
-        log.Debug().Msgf("Function return: %s %+v", s, ss)

-        // The grammar defines the function name as "function", while OpenAI returns "name"
-        func_name, ok := ss[functionNameKey]
-        if !ok {
-            return results
-        }
-        // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
-        args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
-        if !ok {
-            return results
-        }
-        d, _ := json.Marshal(args)
-        funcName, ok := func_name.(string)
-        if !ok {
-            return results
-        }
-        results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
+        results = append(results, FuncCallResults{Name: funcName, Arguments: args})
     }

     return results

@@ -87,7 +87,7 @@ var _ = Describe("LocalAI function parse tests", func() {
         It("should parse the function name and arguments correctly with the name key", func() {
             input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
             functionConfig.ParallelCalls = false
-            functionConfig.NoGrammar = false
+            functionConfig.NoGrammar = true
             functionConfig.ResponseRegex = ""
             functionConfig.FunctionName = true


@@ -100,7 +100,40 @@ var _ = Describe("LocalAI function parse tests", func() {
         It("should parse the function name and arguments correctly with the function key", func() {
             input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
             functionConfig.ParallelCalls = false
-            functionConfig.NoGrammar = false
+            functionConfig.NoGrammar = true
+            functionConfig.ResponseRegex = ""
+            functionConfig.FunctionName = false
+
+            results := ParseFunctionCall(input, functionConfig)
+            Expect(results).To(HaveLen(1))
+            Expect(results[0].Name).To(Equal("add"))
+            Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+        })
+
+        It("Should parse the result by matching the JSONRegexMatch", func() {
+            input := `
+<tool_call>
+{"function": "add", "arguments": {"x": 5, "y": 3}}
+</tool_call>`
+            functionConfig.ParallelCalls = false
+            functionConfig.NoGrammar = true
+            functionConfig.JSONRegexMatch = `(?s)<tool_call>(.*?)</tool_call>`
+            functionConfig.ResponseRegex = ""
+            functionConfig.FunctionName = false
+
+            results := ParseFunctionCall(input, functionConfig)
+            Expect(results).To(HaveLen(1))
+            Expect(results[0].Name).To(Equal("add"))
+            Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+        })
+
+        It("Should parse the result by matching the JSONRegexMatch", func() {
+            input := `
+{"function": "add", "arguments": {"x": 5, "y": 3}}
+</tool_call>`
+            functionConfig.ParallelCalls = false
+            functionConfig.NoGrammar = true
+            functionConfig.JSONRegexMatch = `(?s)(.*?)</tool_call>`
             functionConfig.ResponseRegex = ""
             functionConfig.FunctionName = false

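
The yaml struct tags above show how the new key is spelled in configuration. The sketch below decodes such a fragment with gopkg.in/yaml.v3 into a local struct that copies only the tags visible in this diff; where this block nests inside a full LocalAI model definition is not shown here and is left as an assumption.

```go
package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// functionsConfig copies just the YAML tags visible in this diff; the real
// FunctionsConfig struct has more fields.
type functionsConfig struct {
	NoGrammar      bool   `yaml:"no_grammar"`
	ResponseRegex  string `yaml:"response_regex"`
	JSONRegexMatch string `yaml:"json_regex_match"`
}

func main() {
	// A configuration fragment using the new key.
	data := []byte(`
no_grammar: true
json_regex_match: "(?s)<tool_call>(.*?)</tool_call>"
`)

	var cfg functionsConfig
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg)
}
```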