wip: try to let the JSON grammar return strings as well

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-05-14 22:58:09 +02:00
parent a670318a9f
commit ac47aeaddd
4 changed files with 33 additions and 22 deletions

View File

@ -219,15 +219,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Handle if we should return "name" instead of "functions"
if config.FunctionsConfig.FunctionName {
jsStruct := funcs.ToJSONNameStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
} else {
jsStruct := funcs.ToJSONFunctionStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
}
case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
case input.JSONFunctionGrammarObjectName != nil:
config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls)
config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {

View File

@ -111,21 +111,29 @@ const array = `arr ::=
(",\n" realvalue)*
)? "]"`
// finalizeGrammar renders the rules collected in sc.rules as grammar text,
// one "name ::= rule" line per rule, joined with newlines.
//
// NOTE: this listing shows the pre- and post-change lines of a commit next to
// each other; the lines that mention maybeString are the post-change version.
//
// maybeArray additionally allows the output to be an array ("arr") of values;
// maybeString additionally allows the alternative "string" at the root.
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string {
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray, maybeString bool) string {
var lines []string
// write down the computed rules.
// if maybeArray (post-change: or maybeString) is true, the computed "root"
// rule is emitted under the name "realvalue" instead, and a new "root" rule
// that wraps it is appended below.
// NOTE(review): sc.rules is ranged over directly; if it is a map, Go does not
// guarantee iteration order, so the order of the emitted lines can differ
// between calls.
for name, rule := range sc.rules {
if maybeArray && name == "root" {
if (maybeArray || maybeString) && name == "root" {
name = "realvalue"
}
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
}
// Array only: root is either an array of values or a single value.
if maybeArray {
if maybeArray && !maybeString {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue"))
lines = append(lines, array)
}
// Array and string: root may also be a string.
// NOTE(review): "string" is referenced but not defined in this function;
// confirm sc.rules always contains a rule with that name.
if maybeArray && maybeString {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue | string"))
lines = append(lines, array)
}
// String only: root is a single value or a string.
// NOTE(review): the array rule is appended here as well although this root
// does not reference "arr" — looks unintended, confirm. The same caveat about
// "string" being defined elsewhere applies.
if maybeString && !maybeArray {
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "realvalue | string"))
lines = append(lines, array)
}
return strings.Join(lines, "\n")
}
@ -251,15 +259,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
return def
}
// Grammar runs sc.visit over schema and returns the text produced by
// sc.finalizeGrammar.
//
// NOTE: this listing shows the pre- and post-change lines of a commit next to
// each other; the lines that mention maybeString are the post-change version.
//
// schema is passed to visit both as the node to visit and as the third
// argument (presumably the root used for reference resolution — confirm
// against visit). maybeArray and maybeString are forwarded unchanged to
// finalizeGrammar.
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string {
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray, maybeString bool) string {
sc.visit(schema, "", schema)
return sc.finalizeGrammar(maybeArray)
return sc.finalizeGrammar(maybeArray, maybeString)
}
// GrammarFromBytes decodes b as a JSON object and returns sc.Grammar for the
// decoded schema.
//
// NOTE: this listing shows the pre- and post-change lines of a commit next to
// each other; the lines that mention maybeString are the post-change version.
//
// maybeArray and maybeString are forwarded unchanged to Grammar.
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string {
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray, maybeString bool) string {
var schema map[string]interface{}
// NOTE(review): the decode error is discarded. If b is not valid JSON, schema
// is left at whatever Unmarshal produced (nil for a failed decode into a nil
// map) and Grammar is still called with it.
_ = json.Unmarshal(b, &schema)
return sc.Grammar(schema, maybeArray)
return sc.Grammar(schema, maybeArray, maybeString)
}
func jsonString(v interface{}) string {
@ -302,9 +310,9 @@ type JSONFunctionStructureName struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}
// Grammar serializes j to JSON and converts the result to grammar text using a
// converter created with NewJSONSchemaConverter(propOrder).
//
// NOTE: this listing shows the pre- and post-change lines of a commit next to
// each other; the lines that mention maybeString are the post-change version.
//
// maybeArray and maybeString are forwarded unchanged to GrammarFromBytes.
func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray bool) string {
func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray, maybeString bool) string {
// NOTE(review): the Marshal error is discarded; on failure dat is passed on
// as-is.
dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
}
type JSONFunctionStructureFunction struct {
@ -313,7 +321,7 @@ type JSONFunctionStructureFunction struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}
// Grammar serializes j to JSON and converts the result to grammar text using a
// converter created with NewJSONSchemaConverter(propOrder).
//
// NOTE: this listing shows the pre- and post-change lines of a commit next to
// each other; the lines that mention maybeString are the post-change version.
//
// maybeArray and maybeString are forwarded unchanged to GrammarFromBytes.
func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray bool) string {
func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray, maybeString bool) string {
// NOTE(review): the Marshal error is discarded; on failure dat is passed on
// as-is.
dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
}

View File

@ -141,7 +141,7 @@ root-1-name ::= "\"search\""`
var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false)
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false, false)
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
@ -151,7 +151,7 @@ var _ = Describe("JSON schema grammar tests", func() {
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
})
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false)
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false, false)
results := strings.Split(inputResult3, "\n")
for _, r := range results {
if r != "" {
@ -196,7 +196,7 @@ var _ = Describe("JSON schema grammar tests", func() {
},
}}
grammar := structuredGrammar.Grammar("", false)
grammar := structuredGrammar.Grammar("", false, false)
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
@ -241,7 +241,7 @@ var _ = Describe("JSON schema grammar tests", func() {
},
}}
grammar := structuredGrammar.Grammar("", true)
grammar := structuredGrammar.Grammar("", true, false)
results := strings.Split(inputResult2, "\n")
for _, r := range results {
if r != "" {
@ -286,7 +286,7 @@ var _ = Describe("JSON schema grammar tests", func() {
},
}}
grammar := structuredGrammar.Grammar("", true)
grammar := structuredGrammar.Grammar("", true, false)
results := strings.Split(inputResult4, "\n")
for _, r := range results {
if r != "" {

View File

@ -14,8 +14,11 @@ type FunctionsConfig struct {
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
NoGrammar bool `yaml:"no_grammar"`
ResponseRegex string `yaml:"response_regex"`
// GrammarMessage enables the LLM to return strings and not only JSON objects
GrammarMessage bool `yaml:"grammar_message"`
NoGrammar bool `yaml:"no_grammar"`
ResponseRegex string `yaml:"response_regex"`
JSONRegexMatch string `yaml:"json_regex_match"`