diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index c7afb7bf..86b75601 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -226,12 +226,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup // Update input grammar jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) - g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...) if err == nil { config.Grammar = g } case input.JSONFunctionGrammarObject != nil: - g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...) if err == nil { config.Grammar = g } diff --git a/pkg/functions/function_structure.go b/pkg/functions/function_structure.go index 62cc68fa..c4337d67 100644 --- a/pkg/functions/function_structure.go +++ b/pkg/functions/function_structure.go @@ -1,6 +1,10 @@ package functions -import "encoding/json" +import ( + "encoding/json" + + "github.com/mudler/LocalAI/pkg/functions/grammars" +) type Item struct { Type string `json:"type"` @@ -13,13 +17,27 @@ type JSONFunctionStructure struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) (string, error) { - grammarOpts := &GrammarOption{} +func (j JSONFunctionStructure) Grammar(options ...func(*grammars.GrammarOption)) (string, error) { + grammarOpts := &grammars.GrammarOption{} grammarOpts.Apply(options...) dat, err := json.Marshal(j) if err != nil { return "", err } - return NewJSONSchemaConverter(grammarOpts.PropOrder).GrammarFromBytes(dat, options...) + + converter := NewSchemaConverter(*grammarOpts) + return converter.GrammarFromBytes(dat, options...) 
+} + +type SchemaConverter interface { + GrammarFromBytes([]byte, ...func(*grammars.GrammarOption)) (string, error) +} + +func NewSchemaConverter(opt grammars.GrammarOption) SchemaConverter { + switch { + case opt.SchemaType == grammars.LLama31Schema: + return grammars.NewLLama31SchemaConverter(opt.FunctionName) + } + return grammars.NewJSONSchemaConverter(opt.PropOrder) } diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 2690b8ec..19012d53 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -95,11 +95,3 @@ func (f Functions) Select(name string) Functions { return funcs } - -func jsonString(v interface{}) (string, error) { - b, err := json.Marshal(v) - if err != nil { - return "", err - } - return string(b), nil -} diff --git a/pkg/functions/functions_suite_test.go b/pkg/functions/functions_suite_test.go index 59a90ab0..ab743609 100644 --- a/pkg/functions/functions_suite_test.go +++ b/pkg/functions/functions_suite_test.go @@ -3,23 +3,11 @@ package functions_test import ( "testing" - . "github.com/mudler/LocalAI/pkg/functions" - . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) -func TestGrammar(t *testing.T) { +func TestFunctions(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "Grammar test suite") -} - -func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} { - property := map[string]interface{}{} - property[field1] = FunctionName{Const: name} - property[field2] = Argument{ - Type: "object", - Properties: properties, - } - return property + RunSpecs(t, "Functions test suite") } diff --git a/pkg/functions/bnf_rules.go b/pkg/functions/grammars/bnf_rules.go similarity index 85% rename from pkg/functions/bnf_rules.go rename to pkg/functions/grammars/bnf_rules.go index 13aa3654..469e187a 100644 --- a/pkg/functions/bnf_rules.go +++ b/pkg/functions/grammars/bnf_rules.go @@ -1,6 +1,9 @@ -package functions +package grammars -import "regexp" +import ( + "encoding/json" + "regexp" +) var ( PRIMITIVE_RULES = map[string]string{ @@ -45,3 +48,11 @@ const ( ("," realvalue)* )? "]"` ) + +func jsonString(v interface{}) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} diff --git a/pkg/functions/grammars/grammars_suite_test.go b/pkg/functions/grammars/grammars_suite_test.go new file mode 100644 index 00000000..5ac02bc1 --- /dev/null +++ b/pkg/functions/grammars/grammars_suite_test.go @@ -0,0 +1,25 @@ +package grammars_test + +import ( + "testing" + + . "github.com/mudler/LocalAI/pkg/functions" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestGrammar(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Grammar test suite") +} + +func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} { + property := map[string]interface{}{} + property[field1] = FunctionName{Const: name} + property[field2] = Argument{ + Type: "object", + Properties: properties, + } + return property +} diff --git a/pkg/functions/grammar_json_schema.go b/pkg/functions/grammars/json_schema.go similarity index 67% rename from pkg/functions/grammar_json_schema.go rename to pkg/functions/grammars/json_schema.go index 5ffc0ba5..df4ca6a1 100644 --- a/pkg/functions/grammar_json_schema.go +++ b/pkg/functions/grammars/json_schema.go @@ -1,4 +1,4 @@ -package functions +package grammars // a golang port of https://github.com/ggerganov/llama.cpp/pull/1887 @@ -7,13 +7,11 @@ import ( "fmt" "sort" "strings" - - "github.com/mudler/LocalAI/pkg/utils" ) type JSONSchemaConverter struct { propOrder map[string]int - rules map[string]string + rules Rules } func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter { @@ -60,90 +58,6 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string { return key } -func (sc *JSONSchemaConverter) finalizeGrammar(options ...func(*GrammarOption)) string { - - grammarOpts := &GrammarOption{} - grammarOpts.Apply(options...) - - prefix := grammarOpts.Prefix - maybeArray := grammarOpts.MaybeArray - disableParallelNewLines := grammarOpts.DisableParallelNewLines - maybeString := grammarOpts.MaybeString - noMixedFreeString := grammarOpts.NoMixedFreeString - - var lines []string - - swapRoot := maybeArray || maybeString || prefix != "" - - // write down the computed rules. 
- // if maybeArray is true, we need to add the array rule and slightly tweak the root rule - for name, rule := range sc.rules { - if swapRoot && name == "root" { - name = "realvalue" - } - lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) - } - - if !swapRoot { - return strings.Join(lines, "\n") - } - - newRoot := "realvalue" - if maybeArray { - newRoot = "arr | realvalue" - } - - freestringRule := "mixedstring" - if noMixedFreeString { - freestringRule = "freestring" - } - - if prefix != "" { - // quote newlines in suffix - prefix = utils.EscapeNewLines(prefix) - - if maybeArray && maybeString { - newRoot = "(" + newRoot + ")" - } - - if maybeString { - //newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) " - newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) " - } else { - newRoot = "\"" + prefix + "\" " + "" + newRoot + "" - } - } else if maybeString { - if maybeArray { - // newRoot = "(" + newRoot + ")" - } - - newRoot = freestringRule + " | " + newRoot - } - - lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot)) - if disableParallelNewLines { - lines = append(lines, array) - } else { - lines = append(lines, arrayNewLines) - } - - if maybeArray { - if grammarOpts.ExpectStringsAfterJSON { - lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`) - } else { - lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`) - } - } else { - if grammarOpts.ExpectStringsAfterJSON { - lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`) - } else { - lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`) - } - } - - return strings.Join(lines, "\n") -} - func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) { st, existType := 
schema["type"] var schemaType string @@ -182,7 +96,10 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rule := strings.Join(alternatives, " | ") return sc.addRule(ruleName, rule), nil } else if ref, exists := schema["$ref"].(string); exists { - referencedSchema := sc.resolveReference(ref, rootSchema) + referencedSchema, err := sc.resolveReference(ref, rootSchema) + if err != nil { + return "", err + } return sc.visit(referencedSchema, name, rootSchema) } else if constVal, exists := schema["const"]; exists { literal, err := sc.formatLiteral((constVal)) @@ -257,7 +174,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, } else { primitiveRule, exists := PRIMITIVE_RULES[schemaType] if !exists { - panic(fmt.Sprintf("Unrecognized schema: %v", schema)) + return "", fmt.Errorf("unrecognized schema: %v", schema) } if ruleName == "root" { schemaType = "root" @@ -265,27 +182,23 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, return sc.addRule(schemaType, primitiveRule), nil } } -func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} { +func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) { if !strings.HasPrefix(ref, "#/$defs/") { - panic(fmt.Sprintf("Invalid reference format: %s", ref)) + return nil, fmt.Errorf("invalid reference format: %s", ref) } defKey := strings.TrimPrefix(ref, "#/$defs/") definitions, exists := rootSchema["$defs"].(map[string]interface{}) if !exists { - fmt.Println(rootSchema) - - panic("No definitions found in the schema") + return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema) } def, exists := definitions[defKey].(map[string]interface{}) if !exists { - fmt.Println(definitions) - - panic(fmt.Sprintf("Definition not found: %s", defKey)) + return nil, fmt.Errorf("definition not found: %s %+v", 
defKey, definitions) } - return def + return def, nil } func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) { @@ -294,7 +207,7 @@ func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options .. if err != nil { return "", err } - return sc.finalizeGrammar(options...), nil + return sc.rules.ToGrammar(options...), nil } func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) { diff --git a/pkg/functions/grammar_json_schema_test.go b/pkg/functions/grammars/json_schema_test.go similarity index 99% rename from pkg/functions/grammar_json_schema_test.go rename to pkg/functions/grammars/json_schema_test.go index 56c5fe1e..5fc4a602 100644 --- a/pkg/functions/grammar_json_schema_test.go +++ b/pkg/functions/grammars/json_schema_test.go @@ -1,9 +1,10 @@ -package functions_test +package grammars_test import ( "strings" . "github.com/mudler/LocalAI/pkg/functions" + . "github.com/mudler/LocalAI/pkg/functions/grammars" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) diff --git a/pkg/functions/grammars/llama31_schema.go b/pkg/functions/grammars/llama31_schema.go new file mode 100644 index 00000000..04b74aa5 --- /dev/null +++ b/pkg/functions/grammars/llama31_schema.go @@ -0,0 +1,281 @@ +package grammars + +import ( + "encoding/json" + "fmt" + "regexp" + "sort" + "strings" +) + +type LLama31SchemaConverter struct { + fnName string + rules Rules +} + +func NewLLama31SchemaConverter(fnName string) *LLama31SchemaConverter { + rules := make(map[string]string) + rules["space"] = SPACE_RULE + if fnName == "" { + fnName = "name" + } + + return &LLama31SchemaConverter{ + rules: rules, + fnName: fnName, + } +} + +var GRAMMAR_LITERAL_ESCAPESLlama = map[string]string{ + "\r": `\r`, + "\n": `\n`, +} + +var GRAMMAR_LITERAL_ESCAPE_RELlama = regexp.MustCompile(`[\r\n]`) + +func (sc *LLama31SchemaConverter) formatLiteral(literal interface{}) (string, error) { + jLiteral, err := jsonString(literal) + if err != nil { + return "", err + } + escaped := GRAMMAR_LITERAL_ESCAPE_RELlama.ReplaceAllStringFunc(jLiteral, func(match string) string { + return GRAMMAR_LITERAL_ESCAPESLlama[match] + }) + return escaped, nil +} + +func (sc *LLama31SchemaConverter) formatLiteralQuoted(literal interface{}) (string, error) { + jLiteral, err := jsonString(literal) + if err != nil { + return "", err + } + escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string { + return GRAMMAR_LITERAL_ESCAPES[match] + }) + return fmt.Sprintf(`"%s"`, escaped), nil +} + +func (sc *LLama31SchemaConverter) addRule(name, rule string) string { + escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-") + key := escName + if existingRule, ok := sc.rules[escName]; ok && existingRule != rule { + i := 0 + for { + key = fmt.Sprintf("%s%d", escName, i) + if _, ok := sc.rules[key]; !ok { + break + } + i++ + } + } + sc.rules[key] = rule + return key +} + +func (sc *LLama31SchemaConverter) visit(schema map[string]interface{}, name 
string, rootSchema map[string]interface{}) (string, error) { + st, existType := schema["type"] + var schemaType string + if existType { + schemaType = st.(string) + } + ruleName := name + if name == "" { + ruleName = "root" + } + _, oneOfExists := schema["oneOf"] + _, anyOfExists := schema["anyOf"] + if oneOfExists || anyOfExists { + var alternatives []string + oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{}) + anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{}) + + if oneOfExists { + for i, altSchema := range oneOfSchemas { + alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + if err != nil { + return "", err + } + alternatives = append(alternatives, alternative) + } + } else if anyOfExists { + for i, altSchema := range anyOfSchemas { + alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + if err != nil { + return "", err + } + alternatives = append(alternatives, alternative) + } + } + + rule := strings.Join(alternatives, " | ") + return sc.addRule(ruleName, rule), nil + } else if ref, exists := schema["$ref"].(string); exists { + referencedSchema, err := sc.resolveReference(ref, rootSchema) + if err != nil { + return "", err + } + return sc.visit(referencedSchema, name, rootSchema) + } else if constVal, exists := schema["const"]; exists { + + literal, err := sc.formatLiteral((constVal)) + if err != nil { + return "", err + } + return sc.addRule(ruleName, literal), nil + } else if enumVals, exists := schema["enum"].([]interface{}); exists { + var enumRules []string + for _, enumVal := range enumVals { + enumRule, err := sc.formatLiteralQuoted(enumVal) + if err != nil { + return "", err + } + enumRules = append(enumRules, enumRule) + } + rule := strings.Join(enumRules, " | ") + return sc.addRule(ruleName, rule), nil + } else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && 
exists { + baseProperty := false + depth := strings.Split(name, "-") + if len(depth) == 2 { + baseProperty = true + } + type propData []struct { + propName string + propSchema map[string]interface{} + } + var propPairs propData + + for propName, propSchema := range properties { + propPairs = append(propPairs, struct { + propName string + propSchema map[string]interface{} + }{propName: propName, propSchema: propSchema.(map[string]interface{})}) + } + + sort.Slice(propPairs, func(i, j int) bool { + return propPairs[i].propName < propPairs[j].propName + }) + + var rule strings.Builder + if baseProperty { + rule.WriteString(`"<function="`) + } else { + rule.WriteString(`"{" space`) + } + + if baseProperty { + + namePair := propData{} + for i, propPair := range propPairs { + propName := propPair.propName + if propName == sc.fnName { + namePair = append(namePair, propPair) + // remove namePair from propPairs + propPairs = append(propPairs[:i], propPairs[i+1:]...) + break + } + } + if len(namePair) == 0 { + return "", fmt.Errorf("no function name found in the schema: %v", schema) + } + + propRuleName, err := sc.visit(namePair[0].propSchema, fmt.Sprintf("%s-%s", ruleName, sc.fnName), rootSchema) + if err != nil { + return "", err + } + + rule.WriteString(fmt.Sprintf(` %s ">{" `, propRuleName)) + + for _, propPair := range propPairs { + propName := propPair.propName + propSchema := propPair.propSchema + propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema) + if err != nil { + return "", err + } + + rule.WriteString(propRuleName) + } + + rule.WriteString(` "}</function>"`) + + } else { + for i, propPair := range propPairs { + propName := propPair.propName + propSchema := propPair.propSchema + propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema) + if err != nil { + return "", err + } + lPropName, err := sc.formatLiteralQuoted(propName) + if err != nil { + return "", err + } + if i > 0 { + rule.WriteString(` "," space`) + } + + rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName)) + } + + } + + if !baseProperty { + rule.WriteString(` "}" space`) + } + + return sc.addRule(ruleName, rule.String()), nil + } else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { + itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema) + if err != nil { + return "", err + } + rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? 
"]" space`, itemRuleName, itemRuleName) + return sc.addRule(ruleName, rule), nil + } else { + primitiveRule, exists := PRIMITIVE_RULES[schemaType] + if !exists { + return "", fmt.Errorf("unrecognized schema: %v", schema) + } + if ruleName == "root" { + schemaType = "root" + } + return sc.addRule(schemaType, primitiveRule), nil + } +} +func (sc *LLama31SchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) { + if !strings.HasPrefix(ref, "#/$defs/") { + return nil, fmt.Errorf("invalid reference format: %s", ref) + } + + defKey := strings.TrimPrefix(ref, "#/$defs/") + definitions, exists := rootSchema["$defs"].(map[string]interface{}) + if !exists { + return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema) + } + + def, exists := definitions[defKey].(map[string]interface{}) + if !exists { + return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions) + } + + return def, nil +} + +func (sc *LLama31SchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) { + sc.addRule("freestring", PRIMITIVE_RULES["freestring"]) + _, err := sc.visit(schema, "", schema) + if err != nil { + return "", err + } + return sc.rules.ToGrammar(options...), nil +} + +func (sc *LLama31SchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) { + var schema map[string]interface{} + err := json.Unmarshal(b, &schema) + if err != nil { + return "", err + } + return sc.Grammar(schema, options...) +} diff --git a/pkg/functions/grammars/llama31_schema_test.go b/pkg/functions/grammars/llama31_schema_test.go new file mode 100644 index 00000000..84d09bd5 --- /dev/null +++ b/pkg/functions/grammars/llama31_schema_test.go @@ -0,0 +1,76 @@ +package grammars_test + +import ( + "strings" + + . "github.com/mudler/LocalAI/pkg/functions/grammars" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +const ( + testllama31Input1 = ` + { + "oneOf": [ + { + "type": "object", + "properties": { + "function": {"const": "create_event"}, + "arguments": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "date": {"type": "string"}, + "time": {"type": "string"} + } + } + } + }, + { + "type": "object", + "properties": { + "function": {"const": "search"}, + "arguments": { + "type": "object", + "properties": { + "query": {"type": "string"} + } + } + } + } + ] + }` + // <function=example_function_name>{{"example_name": "example_value"}}</function> + testllama31inputResult1 = `root-0-function ::= "create_event" +freestring ::= ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* space +root-0 ::= "<function=" root-0-function ">{" root-0-arguments "}</function>" +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +root ::= root-0 | root-1 +space ::= " "? +root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "<function=" root-1-function ">{" root-1-arguments "}</function>" +string ::= "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) +)* "\"" space +root-1-function ::= "search"` +) + +var _ = Describe("JSON schema grammar tests", func() { + Context("JSON", func() { + It("generates a valid grammar from JSON schema", func() { + grammar, err := NewLLama31SchemaConverter("function").GrammarFromBytes([]byte(testllama31Input1)) + Expect(err).ToNot(HaveOccurred()) + results := strings.Split(testllama31inputResult1, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) + }) + }) +}) diff --git a/pkg/functions/options.go b/pkg/functions/grammars/options.go similarity index 76% rename from pkg/functions/options.go rename to pkg/functions/grammars/options.go index 3a341a43..07c6c951 100644 --- a/pkg/functions/options.go 
+++ b/pkg/functions/grammars/options.go @@ -1,4 +1,4 @@ -package functions +package grammars type GrammarOption struct { PropOrder string @@ -8,6 +8,9 @@ type GrammarOption struct { MaybeString bool NoMixedFreeString bool ExpectStringsAfterJSON bool + + FunctionName string + SchemaType SchemaConverterType } func (o *GrammarOption) Apply(options ...func(*GrammarOption)) { @@ -48,3 +51,15 @@ func SetPropOrder(order string) func(*GrammarOption) { o.PropOrder = order } } + +func WithSchemaType(schemaType SchemaConverterType) func(*GrammarOption) { + return func(o *GrammarOption) { + o.SchemaType = schemaType + } +} + +func WithFunctionName(name string) func(*GrammarOption) { + return func(o *GrammarOption) { + o.FunctionName = name + } +} diff --git a/pkg/functions/grammars/rules.go b/pkg/functions/grammars/rules.go new file mode 100644 index 00000000..84fc8a25 --- /dev/null +++ b/pkg/functions/grammars/rules.go @@ -0,0 +1,93 @@ +package grammars + +import ( + "fmt" + "strings" + + "github.com/mudler/LocalAI/pkg/utils" +) + +type Rules map[string]string + +func (rules Rules) ToGrammar(options ...func(*GrammarOption)) string { + grammarOpts := &GrammarOption{} + grammarOpts.Apply(options...) + + prefix := grammarOpts.Prefix + maybeArray := grammarOpts.MaybeArray + disableParallelNewLines := grammarOpts.DisableParallelNewLines + maybeString := grammarOpts.MaybeString + noMixedFreeString := grammarOpts.NoMixedFreeString + + var lines []string + + swapRoot := maybeArray || maybeString || prefix != "" + + // write down the computed rules. 
+ // if maybeArray is true, we need to add the array rule and slightly tweak the root rule + for name, rule := range rules { + if swapRoot && name == "root" { + name = "realvalue" + } + lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) + } + + if !swapRoot { + return strings.Join(lines, "\n") + } + + newRoot := "realvalue" + if maybeArray { + newRoot = "arr | realvalue" + } + + freestringRule := "mixedstring" + if noMixedFreeString { + freestringRule = "freestring" + } + + if prefix != "" { + // quote newlines in suffix + prefix = utils.EscapeNewLines(prefix) + + if maybeArray && maybeString { + newRoot = "(" + newRoot + ")" + } + + if maybeString { + //newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) " + newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) " + } else { + newRoot = "\"" + prefix + "\" " + "" + newRoot + "" + } + } else if maybeString { + if maybeArray { + // newRoot = "(" + newRoot + ")" + } + + newRoot = freestringRule + " | " + newRoot + } + + lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot)) + if disableParallelNewLines { + lines = append(lines, array) + } else { + lines = append(lines, arrayNewLines) + } + + if maybeArray { + if grammarOpts.ExpectStringsAfterJSON { + lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`) + } else { + lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`) + } + } else { + if grammarOpts.ExpectStringsAfterJSON { + lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`) + } else { + lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`) + } + } + + return strings.Join(lines, "\n") +} diff --git a/pkg/functions/grammars/types.go b/pkg/functions/grammars/types.go new file mode 100644 index 00000000..1fe6444a --- /dev/null +++ 
b/pkg/functions/grammars/types.go @@ -0,0 +1,33 @@ +package grammars + +type SchemaConverterType int + +const ( + JSONSchema SchemaConverterType = iota + LLama31Schema +) + +const ( + LlamaType string = "llama3.1" + JSONType string = "json" +) + +func (s SchemaConverterType) String() string { + switch s { + case JSONSchema: + return JSONType + case LLama31Schema: + return LlamaType + } + return "unknown" +} + +func NewType(t string) SchemaConverterType { + switch t { + case JSONType: + return JSONSchema + case LlamaType: + return LLama31Schema + } + return JSONSchema +} diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 8e848a60..f5593690 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" + "github.com/mudler/LocalAI/pkg/functions/grammars" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) @@ -22,7 +23,9 @@ type GrammarConfig struct { MixedMode bool `yaml:"mixed_mode"` // NoMixedFreeString disables the mixed mode for free strings - // In this way if the LLM selects a free string, it won't be mixed necessarly with JSON objects + // In this way if the LLM selects a free string, it won't necessarily be mixed with JSON objects. + // For example, if enabled, the LLM returns either a JSON object or a free string, but not a mix of both. + // If disabled (default): the LLM can return a JSON object surrounded by free strings (e.g. `this is the JSON result: { "bar": "baz" } for your question`). 
This forces the LLM to return at least a JSON object, but it's not going to be strict NoMixedFreeString bool `yaml:"no_mixed_free_string"` // NoGrammar disables the grammar parsing and parses the responses directly from the LLM @@ -39,6 +42,10 @@ type GrammarConfig struct { // for instance name,arguments will make print { "name": "foo", "arguments": { "bar": "baz" } } // instead of { "arguments": { "bar": "baz" }, "name": "foo" } PropOrder string `yaml:"properties_order"` + + // SchemaType can be configured to use a specific schema type to force the grammar + // available : json, llama3.1 + SchemaType string `yaml:"schema_type"` } // FunctionsConfig is the configuration for the tool/function call. @@ -92,28 +99,36 @@ type FuncCallResults struct { Arguments string } -func (g GrammarConfig) Options() []func(o *GrammarOption) { - opts := []func(o *GrammarOption){} - if g.MixedMode { - opts = append(opts, EnableMaybeString) +func (g FunctionsConfig) GrammarOptions() []func(o *grammars.GrammarOption) { + opts := []func(o *grammars.GrammarOption){} + if g.GrammarConfig.MixedMode { + opts = append(opts, grammars.EnableMaybeString) } - if g.ParallelCalls { - opts = append(opts, EnableMaybeArray) + if g.GrammarConfig.ParallelCalls { + opts = append(opts, grammars.EnableMaybeArray) } - if g.DisableParallelNewLines { - opts = append(opts, DisableParallelNewLines) + if g.GrammarConfig.DisableParallelNewLines { + opts = append(opts, grammars.DisableParallelNewLines) } - if g.Prefix != "" { - opts = append(opts, SetPrefix(g.Prefix)) + if g.GrammarConfig.Prefix != "" { + opts = append(opts, grammars.SetPrefix(g.GrammarConfig.Prefix)) } - if g.NoMixedFreeString { - opts = append(opts, NoMixedFreeString) + if g.GrammarConfig.NoMixedFreeString { + opts = append(opts, grammars.NoMixedFreeString) } - if g.ExpectStringsAfterJSON { - opts = append(opts, ExpectStringsAfterJSON) + if g.GrammarConfig.ExpectStringsAfterJSON { + opts = append(opts, grammars.ExpectStringsAfterJSON) } - opts = 
append(opts, SetPropOrder(g.PropOrder)) + if g.GrammarConfig.SchemaType != "" { + opts = append(opts, grammars.WithSchemaType(grammars.NewType(g.GrammarConfig.SchemaType))) + } + + if g.FunctionNameKey != "" { + opts = append(opts, grammars.WithFunctionName(g.FunctionNameKey)) + } + + opts = append(opts, grammars.SetPropOrder(g.GrammarConfig.PropOrder)) return opts }