feat(tools): Parallel function calling (#1726)

feat(tools): support returning multiple tool choices

Fixes: https://github.com/mudler/LocalAI/issues/1275
Ettore Di Giacinto 2024-02-20 21:58:45 +01:00 committed by GitHub
parent ed3b50622b
commit 960d314e4f
4 changed files with 235 additions and 108 deletions

View File

@@ -148,6 +148,7 @@ type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}
type TemplateConfig struct {
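
For context: the new parallel_calls knob sits alongside the other function-calling options in a model's YAML config. Below is a minimal, hypothetical sketch of switching it on and decoding it, assuming a gopkg.in/yaml.v3-style decoder (any YAML library with struct tags behaves the same); this snippet is illustrative and not code from this commit.

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// Mirrors the Functions block in the diff above.
type Functions struct {
	DisableNoAction         bool   `yaml:"disable_no_action"`
	NoActionFunctionName    string `yaml:"no_action_function_name"`
	NoActionDescriptionName string `yaml:"no_action_description_name"`
	ParallelCalls           bool   `yaml:"parallel_calls"`
}

func main() {
	// Hypothetical fragment of a model config enabling parallel tool calls.
	cfg := []byte("parallel_calls: true\n")

	var f Functions
	if err := yaml.Unmarshal(cfg, &f); err != nil {
		panic(err)
	}
	fmt.Println(f.ParallelCalls) // true
}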

View File

@@ -64,11 +64,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return true
})
ss := map[string]interface{}{}
name, args := parseFunctionCall(result)
ss["name"], ss["arguments"] = name, args
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
if name == noAction {
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -78,7 +78,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- initialMessage
result, err := handleQuestion(config, req, o, args, prompt)
result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
@@ -98,52 +98,56 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- resp
close(responses)
return
default:
for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to the OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to the OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
}
}
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to the OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: 0,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to the OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: 0,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
close(responses)
}
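
A note on the streaming branch above: every parsed call is sent to the client as two chat.completion.chunk deltas, one carrying the function name and one the stringified arguments, and both share the same tool-call Index i so clients can stitch the pieces back together. Here is a stripped-down sketch of that fan-out; toolCallDelta and send are simplified stand-ins for schema.ToolCall and the responses channel, not the real types.

package main

import "fmt"

// Simplified stand-in for the schema.ToolCall fields used in the deltas.
type toolCallDelta struct {
	Index     int
	Name      string
	Arguments string
}

// Matches the funcCallResults type added later in this diff.
type funcCallResults struct {
	name, arguments string
}

// emitToolCallChunks mirrors the streaming loop above: for each parsed
// call, first a chunk carrying the function name, then one carrying the
// stringified arguments, both sharing the tool-call index i.
func emitToolCallChunks(results []funcCallResults, send func(toolCallDelta)) {
	for i, r := range results {
		send(toolCallDelta{Index: i, Name: r.name})
		send(toolCallDelta{Index: i, Arguments: r.arguments})
	}
}

func main() {
	calls := []funcCallResults{
		{name: "create_event", arguments: `{"title":"demo"}`},
		{name: "search", arguments: `{"query":"LocalAI"}`},
	}
	emitToolCallChunks(calls, func(d toolCallDelta) { fmt.Printf("%+v\n", d) })
}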
@@ -208,9 +212,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
}
// functions are not supported in stream mode (yet?)
@@ -407,57 +411,74 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}))
return nil
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions {
ss := map[string]interface{}{}
if !processFunctions {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
name, args := parseFunctionCall(s)
ss["name"], ss["arguments"] = name, args
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
// if there's nothing to do, reply with a message
if name == noActionName {
result, err := handleQuestion(config, input, o, args, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &result}})
} else {
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// The result is different in the case where we have a tool call
*c = append(*c, schema.Choice{
FinishReason: "tool_calls",
Message: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
// If we are using tools, we condense the function calls into
// a single response choice containing all the tool calls
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
})
)
} else {
// otherwise reply with the function call
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: ss,
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}
return
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err
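
The shaping rule in the non-streaming branch is worth spelling out: when the request carried tools, all parsed calls are condensed into a single choice whose message lists every tool call and whose finish reason is "tool_calls"; with the legacy functions API, each call becomes its own choice finishing with "function_call". A minimal sketch of just that rule, under simplified stand-in types rather than the real schema package:

package main

import "fmt"

// Matches the funcCallResults type added later in this diff.
type funcCallResults struct {
	name, arguments string
}

type toolCall struct {
	Name, Arguments string
}

type choice struct {
	FinishReason string
	ToolCalls    []toolCall
	FunctionCall *toolCall
}

// shapeChoices mirrors the non-streaming branch above: with tools in
// the request, every call is condensed into a single choice finishing
// with "tool_calls"; without tools, each call becomes its own choice
// finishing with "function_call".
func shapeChoices(results []funcCallResults, usingTools bool) []choice {
	if usingTools {
		condensed := choice{FinishReason: "tool_calls"}
		for _, r := range results {
			condensed.ToolCalls = append(condensed.ToolCalls, toolCall{r.name, r.arguments})
		}
		return []choice{condensed}
	}
	var out []choice
	for _, r := range results {
		fc := toolCall{r.name, r.arguments}
		out = append(out, choice{FinishReason: "function_call", FunctionCall: &fc})
	}
	return out
}

func main() {
	calls := []funcCallResults{{"create_event", `{}`}, {"search", `{}`}}
	fmt.Println(len(shapeChoices(calls, true)))  // 1 choice, two tool calls
	fmt.Println(len(shapeChoices(calls, false))) // 2 choices
}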
@@ -528,19 +549,43 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
return backend.Finetune(*config, prompt, prediction.Response), nil
}
func parseFunctionCall(llmresult string) (string, string) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while arguments here is a map[string]interface{}, OpenAI actually wants a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
return func_name.(string), string(d)
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name := s["function"]
args := s["arguments"]
d, _ := json.Marshal(args)
results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while arguments here is a map[string]interface{}, OpenAI actually wants a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
}
return results
}
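
To make the two accepted shapes concrete: with parallel_calls off the model emits a single JSON object, with it on a JSON array of such objects, and both funnel into the same []funcCallResults. Below is a hedged, standalone re-implementation for illustration, with logging and newline escaping omitted; like the original, it assumes well-formed output, since the grammar enforces the shape.

package main

import (
	"encoding/json"
	"fmt"
)

type funcCallResults struct {
	name      string
	arguments string
}

// parse is a trimmed stand-in for parseFunctionCall above, showing the
// two accepted shapes side by side. As in the original, the type
// assertion on "function" is unchecked and would panic on malformed
// output.
func parse(llmresult string, multiple bool) []funcCallResults {
	var out []funcCallResults
	if multiple {
		var ss []map[string]interface{}
		_ = json.Unmarshal([]byte(llmresult), &ss)
		for _, s := range ss {
			d, _ := json.Marshal(s["arguments"])
			out = append(out, funcCallResults{s["function"].(string), string(d)})
		}
		return out
	}
	var s map[string]interface{}
	_ = json.Unmarshal([]byte(llmresult), &s)
	d, _ := json.Marshal(s["arguments"])
	return append(out, funcCallResults{s["function"].(string), string(d)})
}

func main() {
	single := `{"function":"search","arguments":{"query":"LocalAI"}}`
	multi := `[{"function":"create_event","arguments":{"title":"demo"}},
{"function":"search","arguments":{"query":"LocalAI"}}]`
	fmt.Println(parse(single, false)) // [{search {"query":"LocalAI"}}]
	fmt.Println(parse(multi, true))   // one entry per call, order preserved
}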

View File

@@ -105,11 +105,28 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string {
return key
}
func (sc *JSONSchemaConverter) formatGrammar() string {
const array = `arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"`
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string {
var lines []string
// write down the computed rules.
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule
for name, rule := range sc.rules {
if maybeArray && name == "root" {
name = "realvalue"
}
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
}
if maybeArray {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue"))
lines = append(lines, array)
}
return strings.Join(lines, "\n")
}
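
Concretely, finalizeGrammar leaves single-call grammars untouched and, when maybeArray is set, renames the computed root rule to realvalue, re-points root at arr | realvalue, and appends the arr rule. A toy reproduction of just that step on a trimmed rule set, for illustration only:

package main

import (
	"fmt"
	"strings"
)

// Same array rule as in the diff above.
const array = `arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"`

// finalize reproduces only the maybeArray step from finalizeGrammar:
// rename root to realvalue, add a new root accepting either shape,
// and append the arr rule.
func finalize(rules map[string]string, maybeArray bool) string {
	var lines []string
	for name, rule := range rules {
		if maybeArray && name == "root" {
			name = "realvalue"
		}
		lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
	}
	if maybeArray {
		lines = append(lines, "root ::= arr | realvalue", array)
	}
	return strings.Join(lines, "\n")
}

func main() {
	fmt.Println(finalize(map[string]string{"root": "root-0 | root-1"}, true))
	// realvalue ::= root-0 | root-1
	// root ::= arr | realvalue
	// arr ::= ... (the array rule above)
}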
@@ -234,15 +251,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
return def
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string {
sc.visit(schema, "", schema)
return sc.formatGrammar()
return sc.finalizeGrammar(maybeArray)
}
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string {
var schema map[string]interface{}
_ = json.Unmarshal(b, &schema)
return sc.Grammar(schema)
return sc.Grammar(schema, maybeArray)
}
func jsonString(v interface{}) string {
@@ -275,7 +292,7 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}
func (j JSONFunctionStructure) Grammar(propOrder string) string {
func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string {
dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
}

View File

@@ -52,13 +52,32 @@ string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
root-1-function ::= "\"search\""`
inputResult2 = `root-0-function ::= "\"create_event\""
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
realvalue ::= root-0 | root-1
root ::= arr | realvalue
space ::= " "?
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"
root-1-function ::= "\"search\""`
)
var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false)
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
@@ -103,7 +122,7 @@ var _ = Describe("JSON schema grammar tests", func() {
},
}}
grammar := structuredGrammar.Grammar("")
grammar := structuredGrammar.Grammar("", false)
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
@@ -112,5 +131,50 @@ var _ = Describe("JSON schema grammar tests", func() {
}
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
})
It("generates a valid grammar from JSON Objects for multiple function return", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: []Item{
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "create_event",
},
Arguments: Argument{ // this is OpenAI's parameter
Type: "object",
Properties: map[string]interface{}{
"title": map[string]string{"type": "string"},
"date": map[string]string{"type": "string"},
"time": map[string]string{"type": "string"},
},
},
},
},
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "search",
},
Arguments: Argument{
Type: "object",
Properties: map[string]interface{}{
"query": map[string]string{"type": "string"},
},
},
},
},
}}
grammar := structuredGrammar.Grammar("", true)
results := strings.Split(inputResult2, "\n")
for _, r := range results {
if r != "" {
Expect(grammar).To(ContainSubstring(r))
}
}
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
})
})
})
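
For intuition, here is the kind of completion the inputResult2 grammar above permits: a newline-separated JSON array of call objects, which the multipleResults branch of parseFunctionCall then splits into one funcCallResults per call. The sample text is illustrative, not taken from the test suite.

package main

import (
	"encoding/json"
	"fmt"
)

// sampleParallelOutput is an illustrative completion permitted by the
// `root ::= arr | realvalue` grammar above (arr of realvalue objects).
const sampleParallelOutput = `[
{"arguments": {"date": "2024-02-20", "time": "20:00", "title": "demo"}, "function": "create_event"},
{"arguments": {"query": "LocalAI"}, "function": "search"}]`

func main() {
	var calls []map[string]interface{}
	if err := json.Unmarshal([]byte(sampleParallelOutput), &calls); err != nil {
		panic(err)
	}
	for i, c := range calls {
		fmt.Println(i, c["function"]) // 0 create_event, then 1 search
	}
}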