wip: try to let JSON grammar to return strings as well

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-05-14 22:58:09 +02:00
parent a670318a9f
commit ac47aeaddd
4 changed files with 33 additions and 22 deletions

View File

@ -219,15 +219,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Handle if we should return "name" instead of "functions" // Handle if we should return "name" instead of "functions"
if config.FunctionsConfig.FunctionName { if config.FunctionsConfig.FunctionName {
jsStruct := funcs.ToJSONNameStructure() jsStruct := funcs.ToJSONNameStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
} else { } else {
jsStruct := funcs.ToJSONFunctionStructure() jsStruct := funcs.ToJSONFunctionStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
} }
case input.JSONFunctionGrammarObject != nil: case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
case input.JSONFunctionGrammarObjectName != nil: case input.JSONFunctionGrammarObjectName != nil:
config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls) config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
default: default:
// Force picking one of the functions by the request // Force picking one of the functions by the request
if config.FunctionToCall() != "" { if config.FunctionToCall() != "" {

View File

@ -111,21 +111,29 @@ const array = `arr ::=
(",\n" realvalue)* (",\n" realvalue)*
)? "]"` )? "]"`
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string { func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray, maybeString bool) string {
var lines []string var lines []string
// write down the computed rules. // write down the computed rules.
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule // if maybeArray is true, we need to add the array rule and slightly tweak the root rule
for name, rule := range sc.rules { for name, rule := range sc.rules {
if maybeArray && name == "root" { if (maybeArray || maybeString) && name == "root" {
name = "realvalue" name = "realvalue"
} }
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
} }
if maybeArray { if maybeArray && !maybeString {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue")) lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue"))
lines = append(lines, array) lines = append(lines, array)
} }
if maybeArray && maybeString {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue | string"))
lines = append(lines, array)
}
if maybeString && !maybeArray {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "realvalue | string"))
lines = append(lines, array)
}
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
} }
@ -251,15 +259,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
return def return def
} }
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string { func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray, maybeString bool) string {
sc.visit(schema, "", schema) sc.visit(schema, "", schema)
return sc.finalizeGrammar(maybeArray) return sc.finalizeGrammar(maybeArray, maybeString)
} }
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string { func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray, maybeString bool) string {
var schema map[string]interface{} var schema map[string]interface{}
_ = json.Unmarshal(b, &schema) _ = json.Unmarshal(b, &schema)
return sc.Grammar(schema, maybeArray) return sc.Grammar(schema, maybeArray, maybeString)
} }
func jsonString(v interface{}) string { func jsonString(v interface{}) string {
@ -302,9 +310,9 @@ type JSONFunctionStructureName struct {
Defs map[string]interface{} `json:"$defs,omitempty"` Defs map[string]interface{} `json:"$defs,omitempty"`
} }
func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray bool) string { func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray, maybeString bool) string {
dat, _ := json.Marshal(j) dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
} }
type JSONFunctionStructureFunction struct { type JSONFunctionStructureFunction struct {
@ -313,7 +321,7 @@ type JSONFunctionStructureFunction struct {
Defs map[string]interface{} `json:"$defs,omitempty"` Defs map[string]interface{} `json:"$defs,omitempty"`
} }
func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray bool) string { func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray, maybeString bool) string {
dat, _ := json.Marshal(j) dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
} }

View File

@ -141,7 +141,7 @@ root-1-name ::= "\"search\""`
var _ = Describe("JSON schema grammar tests", func() { var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() { Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() { It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false) grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false, false)
results := strings.Split(inputResult1, "\n") results := strings.Split(inputResult1, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -151,7 +151,7 @@ var _ = Describe("JSON schema grammar tests", func() {
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
}) })
It("generates a valid grammar from JSON schema", func() { It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false) grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false, false)
results := strings.Split(inputResult3, "\n") results := strings.Split(inputResult3, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -196,7 +196,7 @@ var _ = Describe("JSON schema grammar tests", func() {
}, },
}} }}
grammar := structuredGrammar.Grammar("", false) grammar := structuredGrammar.Grammar("", false, false)
results := strings.Split(inputResult1, "\n") results := strings.Split(inputResult1, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -241,7 +241,7 @@ var _ = Describe("JSON schema grammar tests", func() {
}, },
}} }}
grammar := structuredGrammar.Grammar("", true) grammar := structuredGrammar.Grammar("", true, false)
results := strings.Split(inputResult2, "\n") results := strings.Split(inputResult2, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {
@ -286,7 +286,7 @@ var _ = Describe("JSON schema grammar tests", func() {
}, },
}} }}
grammar := structuredGrammar.Grammar("", true) grammar := structuredGrammar.Grammar("", true, false)
results := strings.Split(inputResult4, "\n") results := strings.Split(inputResult4, "\n")
for _, r := range results { for _, r := range results {
if r != "" { if r != "" {

View File

@ -14,6 +14,9 @@ type FunctionsConfig struct {
NoActionFunctionName string `yaml:"no_action_function_name"` NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"` NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"` ParallelCalls bool `yaml:"parallel_calls"`
// GrammarMessage enables the LLM to return strings and not only JSON objects
GrammarMessage bool `yaml:"grammar_message"`
NoGrammar bool `yaml:"no_grammar"` NoGrammar bool `yaml:"no_grammar"`
ResponseRegex string `yaml:"response_regex"` ResponseRegex string `yaml:"response_regex"`