feat(tools): Parallel function calling (#1726)

feat(tools): support returning multiple tool choices

Fixes: https://github.com/mudler/LocalAI/issues/1275
This commit is contained in:
Ettore Di Giacinto 2024-02-20 21:58:45 +01:00 committed by GitHub
parent ed3b50622b
commit 960d314e4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 235 additions and 108 deletions

View File

@ -148,6 +148,7 @@ type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"` DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"` NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"` NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
} }
type TemplateConfig struct { type TemplateConfig struct {

View File

@ -64,11 +64,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return true return true
}) })
ss := map[string]interface{}{} results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
name, args := parseFunctionCall(result) noActionToRun := len(results) > 0 && results[0].name == noAction
ss["name"], ss["arguments"] = name, args
if name == noAction { switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{ initialMessage := schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
@ -78,7 +78,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
responses <- initialMessage responses <- initialMessage
result, err := handleQuestion(config, req, o, args, prompt) result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
if err != nil { if err != nil {
log.Error().Msgf("error handling question: %s", err.Error()) log.Error().Msgf("error handling question: %s", err.Error())
return return
@ -98,9 +98,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
responses <- resp responses <- resp
close(responses)
return default:
} for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{ initialMessage := schema.OpenAIResponse{
ID: id, ID: id,
@ -111,7 +112,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Role: "assistant", Role: "assistant",
ToolCalls: []schema.ToolCall{ ToolCalls: []schema.ToolCall{
{ {
Index: 0, Index: i,
ID: id, ID: id,
Type: "function", Type: "function",
FunctionCall: schema.FunctionCall{ FunctionCall: schema.FunctionCall{
@ -133,7 +134,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Role: "assistant", Role: "assistant",
ToolCalls: []schema.ToolCall{ ToolCalls: []schema.ToolCall{
{ {
Index: 0, Index: i,
ID: id, ID: id,
Type: "function", Type: "function",
FunctionCall: schema.FunctionCall{ FunctionCall: schema.FunctionCall{
@ -144,6 +145,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}}}, }}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
} }
}
}
close(responses) close(responses)
} }
@ -208,9 +212,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
// Update input grammar // Update input grammar
jsStruct := funcs.ToJSONStructure() jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("") config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil { } else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
} }
// functions are not supported in stream mode (yet?) // functions are not supported in stream mode (yet?)
@ -407,31 +411,45 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
})) }))
return nil return nil
// no streaming mode
default: default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions { if !processFunctions {
ss := map[string]interface{}{} // no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
name, args := parseFunctionCall(s) results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
ss["name"], ss["arguments"] = name, args noActionsToRun := len(results) > 0 && results[0].name == noActionName
// if do nothing, reply with a message switch {
if name == noActionName { case noActionsToRun:
result, err := handleQuestion(config, input, o, args, predInput) result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
if err != nil { if err != nil {
log.Error().Msgf("error handling question: %s", err.Error()) log.Error().Msgf("error handling question: %s", err.Error())
return return
} }
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &result}})
} else {
if len(input.Tools) > 0 {
// Result is different in the case we have a tool call
*c = append(*c, schema.Choice{ *c = append(*c, schema.Choice{
FinishReason: "tool_calls", Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{ Message: &schema.Message{
Role: "assistant", Role: "assistant",
ToolCalls: []schema.ToolCall{ },
{ }
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id, ID: id,
Type: "function", Type: "function",
FunctionCall: schema.FunctionCall{ FunctionCall: schema.FunctionCall{
@ -439,25 +457,28 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
Arguments: args, Arguments: args,
}, },
}, },
}, )
},
})
} else { } else {
// otherwise reply with the function call // otherwise we return more choices directly
*c = append(*c, schema.Choice{ *c = append(*c, schema.Choice{
FinishReason: "function_call", FinishReason: "function_call",
Message: &schema.Message{ Message: &schema.Message{
Role: "assistant", Role: "assistant",
FunctionCall: ss, FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
}, },
}) })
} }
} }
return if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
} }
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil) }, nil)
if err != nil { if err != nil {
return err return err
@ -528,7 +549,28 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
return backend.Finetune(*config, prompt, prediction.Response), nil return backend.Finetune(*config, prompt, prediction.Response), nil
} }
func parseFunctionCall(llmresult string) (string, string) { type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name := s["function"]
args := s["arguments"]
d, _ := json.Marshal(args)
results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?) // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{} ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients // This prevents newlines from breaking JSON parsing for clients
@ -542,5 +584,8 @@ func parseFunctionCall(llmresult string) (string, string) {
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args) d, _ := json.Marshal(args)
return func_name.(string), string(d) results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
}
return results
} }

View File

@ -105,11 +105,28 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string {
return key return key
} }
func (sc *JSONSchemaConverter) formatGrammar() string { const array = `arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"`
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string {
var lines []string var lines []string
// write down the computed rules.
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule
for name, rule := range sc.rules { for name, rule := range sc.rules {
if maybeArray && name == "root" {
name = "realvalue"
}
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
} }
if maybeArray {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue"))
lines = append(lines, array)
}
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
} }
@ -234,15 +251,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
return def return def
} }
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string {
sc.visit(schema, "", schema) sc.visit(schema, "", schema)
return sc.formatGrammar() return sc.finalizeGrammar(maybeArray)
} }
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string {
var schema map[string]interface{} var schema map[string]interface{}
_ = json.Unmarshal(b, &schema) _ = json.Unmarshal(b, &schema)
return sc.Grammar(schema) return sc.Grammar(schema, maybeArray)
} }
func jsonString(v interface{}) string { func jsonString(v interface{}) string {
@ -275,7 +292,7 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"` Defs map[string]interface{} `json:"$defs,omitempty"`
} }
func (j JSONFunctionStructure) Grammar(propOrder string) string { func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string {
dat, _ := json.Marshal(j) dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
} }

View File

@ -52,13 +52,32 @@ string ::= "\"" (
[^"\\] | [^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space )* "\"" space
root-1-function ::= "\"search\""`
inputResult2 = `root-0-function ::= "\"create_event\""
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
realvalue ::= root-0 | root-1
root ::= arr | realvalue
space ::= " "?
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"
root-1-function ::= "\"search\""` root-1-function ::= "\"search\""`
) )
var _ = Describe("JSON schema grammar tests", func() { var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() { Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() { It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false)
results := strings.Split(inputResult1, "\n") results := strings.Split(inputResult1, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -103,7 +122,7 @@ var _ = Describe("JSON schema grammar tests", func() {
}, },
}} }}
grammar := structuredGrammar.Grammar("") grammar := structuredGrammar.Grammar("", false)
results := strings.Split(inputResult1, "\n") results := strings.Split(inputResult1, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -112,5 +131,50 @@ var _ = Describe("JSON schema grammar tests", func() {
} }
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
}) })
It("generates a valid grammar from JSON Objects for multiple function return", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: []Item{
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "create_event",
},
Arguments: Argument{ // this is OpenAI's parameter
Type: "object",
Properties: map[string]interface{}{
"title": map[string]string{"type": "string"},
"date": map[string]string{"type": "string"},
"time": map[string]string{"type": "string"},
},
},
},
},
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "search",
},
Arguments: Argument{
Type: "object",
Properties: map[string]interface{}{
"query": map[string]string{"type": "string"},
},
},
},
},
}}
grammar := structuredGrammar.Grammar("", true)
results := strings.Split(inputResult2, "\n")
for _, r := range results {
if r != "" {
Expect(grammar).To(ContainSubstring(r))
}
}
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
})
}) })
}) })