diff --git a/api/config/config.go b/api/config/config.go index 48d1b791..5ea16828 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -148,6 +148,7 @@ type Functions struct { DisableNoAction bool `yaml:"disable_no_action"` NoActionFunctionName string `yaml:"no_action_function_name"` NoActionDescriptionName string `yaml:"no_action_description_name"` + ParallelCalls bool `yaml:"parallel_calls"` } type TemplateConfig struct { diff --git a/api/openai/chat.go b/api/openai/chat.go index 68c3a291..d34f2a0c 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -64,11 +64,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return true }) - ss := map[string]interface{}{} - name, args := parseFunctionCall(result) - ss["name"], ss["arguments"] = name, args + results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls) + noActionToRun := len(results) > 0 && results[0].name == noAction - if name == noAction { + switch { + case noActionToRun: initialMessage := schema.OpenAIResponse{ ID: id, Created: created, @@ -78,7 +78,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } responses <- initialMessage - result, err := handleQuestion(config, req, o, args, prompt) + result, err := handleQuestion(config, req, o, results[0].arguments, prompt) if err != nil { log.Error().Msgf("error handling question: %s", err.Error()) return @@ -98,52 +98,56 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } responses <- resp - close(responses) - return + + default: + for i, ss := range results { + name, args := ss.name, ss.arguments + + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + responses <- schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + } } - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: 0, - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - responses <- schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: 0, - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Arguments: args, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } close(responses) } @@ -208,9 +212,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) // Update input grammar jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) } // functions are not supported in stream mode (yet?) @@ -407,57 +411,74 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) })) return nil + // no streaming mode default: result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { - if processFunctions { - ss := map[string]interface{}{} + if !processFunctions { + // no function is called, just reply and use stop as finish reason + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + return + } - name, args := parseFunctionCall(s) - ss["name"], ss["arguments"] = name, args + results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls) + noActionsToRun := len(results) > 0 && results[0].name == noActionName - // if do nothing, reply with a message - if name == noActionName { - result, err := handleQuestion(config, input, o, args, predInput) - if err != nil { - log.Error().Msgf("error handling question: %s", err.Error()) - return - } - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &result}}) - } else { + switch { + case noActionsToRun: + result, err := handleQuestion(config, input, o, results[0].arguments, predInput) + if err != nil { + log.Error().Msgf("error handling question: %s", err.Error()) + return + } + *c = append(*c, schema.Choice{ + Message: &schema.Message{Role: "assistant", Content: &result}}) + default: + toolChoice := schema.Choice{ + Message: &schema.Message{ + Role: "assistant", + }, + } + + if len(input.Tools) > 0 { + toolChoice.FinishReason = "tool_calls" + } + + for _, ss := range results { + name, args := ss.name, ss.arguments if len(input.Tools) > 0 { - // Result is different in the case we have a tool call - *c = append(*c, schema.Choice{ - FinishReason: "tool_calls", - Message: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, - }, + // If we are using tools, we condense the function calls into + // a single response choice with all the tools + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, }, }, - }) + ) } else { - // otherwise reply with the function call + // otherwise we return more choices directly *c = append(*c, schema.Choice{ FinishReason: "function_call", Message: &schema.Message{ - Role: "assistant", - FunctionCall: ss, + Role: "assistant", + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, }, }) } } - return + if len(input.Tools) > 0 { + // we need to append our result if we are using tools + *c = append(*c, toolChoice) + } } - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) }, nil) if err != nil { return err @@ -528,19 +549,43 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio return backend.Finetune(*config, prompt, prediction.Response), nil } -func parseFunctionCall(llmresult string) (string, string) { - // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s := utils.EscapeNewLines(llmresult) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - return func_name.(string), string(d) +type funcCallResults struct { + name string + arguments string +} + +func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { + results := []funcCallResults{} + + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + for _, s := range ss { + func_name := s["function"] + args := s["arguments"] + d, _ := json.Marshal(args) + results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)}) + } + + return results } diff --git a/pkg/grammar/json_schema.go b/pkg/grammar/json_schema.go index 40d7f4e6..76f9778f 100644 --- a/pkg/grammar/json_schema.go +++ b/pkg/grammar/json_schema.go @@ -105,11 +105,28 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string { return key } -func (sc *JSONSchemaConverter) formatGrammar() string { +const array = `arr ::= + "[\n" ( + realvalue + (",\n" realvalue)* + )? "]"` + +func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string { var lines []string + // write down the computed rules. + // if maybeArray is true, we need to add the array rule and slightly tweak the root rule for name, rule := range sc.rules { + if maybeArray && name == "root" { + name = "realvalue" + } lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) } + + if maybeArray { + lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue")) + lines = append(lines, array) + } + return strings.Join(lines, "\n") } @@ -234,15 +251,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin return def } -func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { +func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string { sc.visit(schema, "", schema) - return sc.formatGrammar() + return sc.finalizeGrammar(maybeArray) } -func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { +func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string { var schema map[string]interface{} _ = json.Unmarshal(b, &schema) - return sc.Grammar(schema) + return sc.Grammar(schema, maybeArray) } func jsonString(v interface{}) string { @@ -275,7 +292,7 @@ type JSONFunctionStructure struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructure) Grammar(propOrder string) string { +func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string { dat, _ := json.Marshal(j) - return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) } diff --git a/pkg/grammar/json_schema_test.go b/pkg/grammar/json_schema_test.go index 9d4086cb..39d2a4d5 100644 --- a/pkg/grammar/json_schema_test.go +++ b/pkg/grammar/json_schema_test.go @@ -52,13 +52,32 @@ string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) )* "\"" space +root-1-function ::= "\"search\""` + + inputResult2 = `root-0-function ::= "\"create_event\"" +root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +realvalue ::= root-0 | root-1 +root ::= arr | realvalue +space ::= " "? +root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space +string ::= "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) +)* "\"" space +arr ::= + "[\n" ( + realvalue + (",\n" realvalue)* + )? "]" root-1-function ::= "\"search\""` ) var _ = Describe("JSON schema grammar tests", func() { Context("JSON", func() { It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -103,7 +122,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, }} - grammar := structuredGrammar.Grammar("") + grammar := structuredGrammar.Grammar("", false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -112,5 +131,50 @@ var _ = Describe("JSON schema grammar tests", func() { } Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) }) + + It("generates a valid grammar from JSON Objects for multiple function return", func() { + structuredGrammar := JSONFunctionStructure{ + OneOf: []Item{ + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "create_event", + }, + Arguments: Argument{ // this is OpenAI's parameter + Type: "object", + Properties: map[string]interface{}{ + "title": map[string]string{"type": "string"}, + "date": map[string]string{"type": "string"}, + "time": map[string]string{"type": "string"}, + }, + }, + }, + }, + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "search", + }, + Arguments: Argument{ + Type: "object", + Properties: map[string]interface{}{ + "query": map[string]string{"type": "string"}, + }, + }, + }, + }, + }} + + grammar := structuredGrammar.Grammar("", true) + results := strings.Split(inputResult2, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar) + }) }) })