mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
wip: try to let JSON grammar to return strings as well
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
a670318a9f
commit
ac47aeaddd
@ -219,15 +219,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
// Handle if we should return "name" instead of "functions"
|
||||
if config.FunctionsConfig.FunctionName {
|
||||
jsStruct := funcs.ToJSONNameStructure()
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
|
||||
} else {
|
||||
jsStruct := funcs.ToJSONFunctionStructure()
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
|
||||
}
|
||||
case input.JSONFunctionGrammarObject != nil:
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
|
||||
case input.JSONFunctionGrammarObjectName != nil:
|
||||
config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls, config.FunctionsConfig.GrammarMessage)
|
||||
default:
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
|
@ -111,21 +111,29 @@ const array = `arr ::=
|
||||
(",\n" realvalue)*
|
||||
)? "]"`
|
||||
|
||||
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string {
|
||||
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray, maybeString bool) string {
|
||||
var lines []string
|
||||
// write down the computed rules.
|
||||
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule
|
||||
for name, rule := range sc.rules {
|
||||
if maybeArray && name == "root" {
|
||||
if (maybeArray || maybeString) && name == "root" {
|
||||
name = "realvalue"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
|
||||
}
|
||||
|
||||
if maybeArray {
|
||||
if maybeArray && !maybeString {
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue"))
|
||||
lines = append(lines, array)
|
||||
}
|
||||
if maybeArray && maybeString {
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue | string"))
|
||||
lines = append(lines, array)
|
||||
}
|
||||
if maybeString && !maybeArray {
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "realvalue | string"))
|
||||
lines = append(lines, array)
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@ -251,15 +259,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
|
||||
|
||||
return def
|
||||
}
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string {
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray, maybeString bool) string {
|
||||
sc.visit(schema, "", schema)
|
||||
return sc.finalizeGrammar(maybeArray)
|
||||
return sc.finalizeGrammar(maybeArray, maybeString)
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string {
|
||||
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray, maybeString bool) string {
|
||||
var schema map[string]interface{}
|
||||
_ = json.Unmarshal(b, &schema)
|
||||
return sc.Grammar(schema, maybeArray)
|
||||
return sc.Grammar(schema, maybeArray, maybeString)
|
||||
}
|
||||
|
||||
func jsonString(v interface{}) string {
|
||||
@ -302,9 +310,9 @@ type JSONFunctionStructureName struct {
|
||||
Defs map[string]interface{} `json:"$defs,omitempty"`
|
||||
}
|
||||
|
||||
func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray bool) string {
|
||||
func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray, maybeString bool) string {
|
||||
dat, _ := json.Marshal(j)
|
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
|
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
|
||||
}
|
||||
|
||||
type JSONFunctionStructureFunction struct {
|
||||
@ -313,7 +321,7 @@ type JSONFunctionStructureFunction struct {
|
||||
Defs map[string]interface{} `json:"$defs,omitempty"`
|
||||
}
|
||||
|
||||
func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray bool) string {
|
||||
func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray, maybeString bool) string {
|
||||
dat, _ := json.Marshal(j)
|
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
|
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray, maybeString)
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ root-1-name ::= "\"search\""`
|
||||
var _ = Describe("JSON schema grammar tests", func() {
|
||||
Context("JSON", func() {
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false)
|
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false, false)
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
@ -151,7 +151,7 @@ var _ = Describe("JSON schema grammar tests", func() {
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false)
|
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false, false)
|
||||
results := strings.Split(inputResult3, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
@ -196,7 +196,7 @@ var _ = Describe("JSON schema grammar tests", func() {
|
||||
},
|
||||
}}
|
||||
|
||||
grammar := structuredGrammar.Grammar("", false)
|
||||
grammar := structuredGrammar.Grammar("", false, false)
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
@ -241,7 +241,7 @@ var _ = Describe("JSON schema grammar tests", func() {
|
||||
},
|
||||
}}
|
||||
|
||||
grammar := structuredGrammar.Grammar("", true)
|
||||
grammar := structuredGrammar.Grammar("", true, false)
|
||||
results := strings.Split(inputResult2, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
@ -286,7 +286,7 @@ var _ = Describe("JSON schema grammar tests", func() {
|
||||
},
|
||||
}}
|
||||
|
||||
grammar := structuredGrammar.Grammar("", true)
|
||||
grammar := structuredGrammar.Grammar("", true, false)
|
||||
results := strings.Split(inputResult4, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
|
@ -14,8 +14,11 @@ type FunctionsConfig struct {
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
ParallelCalls bool `yaml:"parallel_calls"`
|
||||
NoGrammar bool `yaml:"no_grammar"`
|
||||
ResponseRegex string `yaml:"response_regex"`
|
||||
|
||||
// GrammarMessage enables the LLM to return strings and not only JSON objects
|
||||
GrammarMessage bool `yaml:"grammar_message"`
|
||||
NoGrammar bool `yaml:"no_grammar"`
|
||||
ResponseRegex string `yaml:"response_regex"`
|
||||
|
||||
JSONRegexMatch string `yaml:"json_regex_match"`
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user