diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index ccbf0946..35bdfd82 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -219,15 +219,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup // Handle if we should return "name" instead of "functions" if config.FunctionsConfig.FunctionName { jsStruct := funcs.ToJSONNameStructure() - config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage) } else { jsStruct := funcs.ToJSONFunctionStructure() - config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage) } case input.JSONFunctionGrammarObject != nil: - config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) + config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage) case input.JSONFunctionGrammarObjectName != nil: - config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls) + config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage) default: // Force picking one of the functions by the request if config.FunctionToCall() != "" { diff --git a/pkg/functions/grammar_json_schema.go b/pkg/functions/grammar_json_schema.go index ede52fab..9ce8b583 100644 --- a/pkg/functions/grammar_json_schema.go +++ b/pkg/functions/grammar_json_schema.go @@ -111,21 +111,29 @@ const array = `arr ::= (",\n" realvalue)* )? "]"` -func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string { +func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray, maybeString bool) string { var lines []string // write down the computed rules. // if maybeArray is true, we need to add the array rule and slightly tweak the root rule for name, rule := range sc.rules { - if maybeArray && name == "root" { + if (maybeArray || maybeString) && name == "root" { name = "realvalue" } lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) } - if maybeArray { + if maybeArray && !maybeString { lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue")) lines = append(lines, array) } + if maybeArray && maybeString { + lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue | string")) + lines = append(lines, array) + } + if maybeString && !maybeArray { + lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "realvalue | string")) + lines = append(lines, array) + } return strings.Join(lines, "\n") } @@ -251,15 +259,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin return def } -func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string { +func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray, maybeString bool) string { sc.visit(schema, "", schema) - return sc.finalizeGrammar(maybeArray) + return sc.finalizeGrammar(maybeArray, maybeString) } -func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string { +func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray, maybeString bool) string { var schema map[string]interface{} _ = json.Unmarshal(b, &schema) - return sc.Grammar(schema, maybeArray) + return sc.Grammar(schema, maybeArray, maybeString) } func jsonString(v interface{}) string { @@ -302,9 +310,9 @@ type JSONFunctionStructureName struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray bool) string { +func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray, maybeString bool) string { dat, _ := json.Marshal(j) - return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString) } type JSONFunctionStructureFunction struct { @@ -313,7 +321,7 @@ type JSONFunctionStructureFunction struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray bool) string { +func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray, maybeString bool) string { dat, _ := json.Marshal(j) - return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString) } diff --git a/pkg/functions/grammar_json_schema_test.go b/pkg/functions/grammar_json_schema_test.go index 83fae372..da8beb9b 100644 --- a/pkg/functions/grammar_json_schema_test.go +++ b/pkg/functions/grammar_json_schema_test.go @@ -141,7 +141,7 @@ root-1-name ::= "\"search\""` var _ = Describe("JSON schema grammar tests", func() { Context("JSON", func() { It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false) + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false, false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -151,7 +151,7 @@ var _ = Describe("JSON schema grammar tests", func() { Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) }) It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false) + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false, false) results := strings.Split(inputResult3, "\n") for _, r := range results { if r != "" { @@ -196,7 +196,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, }} - grammar := structuredGrammar.Grammar("", false) + grammar := structuredGrammar.Grammar("", false, false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -241,7 +241,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, }} - grammar := structuredGrammar.Grammar("", true) + grammar := structuredGrammar.Grammar("", true, false) results := strings.Split(inputResult2, "\n") for _, r := range results { if r != "" { @@ -286,7 +286,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, }} - grammar := structuredGrammar.Grammar("", true) + grammar := structuredGrammar.Grammar("", true, false) results := strings.Split(inputResult4, "\n") for _, r := range results { if r != "" { diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index c6941ff6..7c645075 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -14,8 +14,11 @@ type FunctionsConfig struct { NoActionFunctionName string `yaml:"no_action_function_name"` NoActionDescriptionName string `yaml:"no_action_description_name"` ParallelCalls bool `yaml:"parallel_calls"` - NoGrammar bool `yaml:"no_grammar"` - ResponseRegex string `yaml:"response_regex"` + + // GrammarMessage enables the LLM to return strings and not only JSON objects + GrammarMessage bool `yaml:"grammar_message"` + NoGrammar bool `yaml:"no_grammar"` + ResponseRegex string `yaml:"response_regex"` JSONRegexMatch string `yaml:"json_regex_match"`