diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index c7afb7bf..86b75601 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -226,12 +226,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Update input grammar
jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
- g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
+ g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
config.Grammar = g
}
case input.JSONFunctionGrammarObject != nil:
- g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
+ g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
config.Grammar = g
}
diff --git a/pkg/functions/function_structure.go b/pkg/functions/function_structure.go
index 62cc68fa..c4337d67 100644
--- a/pkg/functions/function_structure.go
+++ b/pkg/functions/function_structure.go
@@ -1,6 +1,10 @@
package functions
-import "encoding/json"
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/pkg/functions/grammars"
+)
type Item struct {
Type string `json:"type"`
@@ -13,13 +17,27 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}
-func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) (string, error) {
- grammarOpts := &GrammarOption{}
+func (j JSONFunctionStructure) Grammar(options ...func(*grammars.GrammarOption)) (string, error) {
+ grammarOpts := &grammars.GrammarOption{}
grammarOpts.Apply(options...)
dat, err := json.Marshal(j)
if err != nil {
return "", err
}
- return NewJSONSchemaConverter(grammarOpts.PropOrder).GrammarFromBytes(dat, options...)
+
+ converter := NewSchemaConverter(*grammarOpts)
+ return converter.GrammarFromBytes(dat, options...)
+}
+
+type SchemaConverter interface {
+ GrammarFromBytes([]byte, ...func(*grammars.GrammarOption)) (string, error)
+}
+
+func NewSchemaConverter(opt grammars.GrammarOption) SchemaConverter {
+ switch {
+ case opt.SchemaType == grammars.LLama31Schema:
+ return grammars.NewLLama31SchemaConverter(opt.FunctionName)
+ }
+ return grammars.NewJSONSchemaConverter(opt.PropOrder)
}
diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go
index 2690b8ec..19012d53 100644
--- a/pkg/functions/functions.go
+++ b/pkg/functions/functions.go
@@ -95,11 +95,3 @@ func (f Functions) Select(name string) Functions {
return funcs
}
-
-func jsonString(v interface{}) (string, error) {
- b, err := json.Marshal(v)
- if err != nil {
- return "", err
- }
- return string(b), nil
-}
diff --git a/pkg/functions/functions_suite_test.go b/pkg/functions/functions_suite_test.go
index 59a90ab0..ab743609 100644
--- a/pkg/functions/functions_suite_test.go
+++ b/pkg/functions/functions_suite_test.go
@@ -3,23 +3,11 @@ package functions_test
import (
"testing"
- . "github.com/mudler/LocalAI/pkg/functions"
-
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
-func TestGrammar(t *testing.T) {
+func TestFunctions(t *testing.T) {
RegisterFailHandler(Fail)
- RunSpecs(t, "Grammar test suite")
-}
-
-func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} {
- property := map[string]interface{}{}
- property[field1] = FunctionName{Const: name}
- property[field2] = Argument{
- Type: "object",
- Properties: properties,
- }
- return property
+ RunSpecs(t, "Functions test suite")
}
diff --git a/pkg/functions/bnf_rules.go b/pkg/functions/grammars/bnf_rules.go
similarity index 85%
rename from pkg/functions/bnf_rules.go
rename to pkg/functions/grammars/bnf_rules.go
index 13aa3654..469e187a 100644
--- a/pkg/functions/bnf_rules.go
+++ b/pkg/functions/grammars/bnf_rules.go
@@ -1,6 +1,9 @@
-package functions
+package grammars
-import "regexp"
+import (
+ "encoding/json"
+ "regexp"
+)
var (
PRIMITIVE_RULES = map[string]string{
@@ -45,3 +48,11 @@ const (
("," realvalue)*
)? "]"`
)
+
+func jsonString(v interface{}) (string, error) {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return "", err
+ }
+ return string(b), nil
+}
diff --git a/pkg/functions/grammars/grammars_suite_test.go b/pkg/functions/grammars/grammars_suite_test.go
new file mode 100644
index 00000000..5ac02bc1
--- /dev/null
+++ b/pkg/functions/grammars/grammars_suite_test.go
@@ -0,0 +1,25 @@
+package grammars_test
+
+import (
+ "testing"
+
+ . "github.com/mudler/LocalAI/pkg/functions"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestGrammar(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Grammar test suite")
+}
+
+func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} {
+ property := map[string]interface{}{}
+ property[field1] = FunctionName{Const: name}
+ property[field2] = Argument{
+ Type: "object",
+ Properties: properties,
+ }
+ return property
+}
diff --git a/pkg/functions/grammar_json_schema.go b/pkg/functions/grammars/json_schema.go
similarity index 67%
rename from pkg/functions/grammar_json_schema.go
rename to pkg/functions/grammars/json_schema.go
index 5ffc0ba5..df4ca6a1 100644
--- a/pkg/functions/grammar_json_schema.go
+++ b/pkg/functions/grammars/json_schema.go
@@ -1,4 +1,4 @@
-package functions
+package grammars
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
@@ -7,13 +7,11 @@ import (
"fmt"
"sort"
"strings"
-
- "github.com/mudler/LocalAI/pkg/utils"
)
type JSONSchemaConverter struct {
propOrder map[string]int
- rules map[string]string
+ rules Rules
}
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
@@ -60,90 +58,6 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string {
return key
}
-func (sc *JSONSchemaConverter) finalizeGrammar(options ...func(*GrammarOption)) string {
-
- grammarOpts := &GrammarOption{}
- grammarOpts.Apply(options...)
-
- prefix := grammarOpts.Prefix
- maybeArray := grammarOpts.MaybeArray
- disableParallelNewLines := grammarOpts.DisableParallelNewLines
- maybeString := grammarOpts.MaybeString
- noMixedFreeString := grammarOpts.NoMixedFreeString
-
- var lines []string
-
- swapRoot := maybeArray || maybeString || prefix != ""
-
- // write down the computed rules.
- // if maybeArray is true, we need to add the array rule and slightly tweak the root rule
- for name, rule := range sc.rules {
- if swapRoot && name == "root" {
- name = "realvalue"
- }
- lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
- }
-
- if !swapRoot {
- return strings.Join(lines, "\n")
- }
-
- newRoot := "realvalue"
- if maybeArray {
- newRoot = "arr | realvalue"
- }
-
- freestringRule := "mixedstring"
- if noMixedFreeString {
- freestringRule = "freestring"
- }
-
- if prefix != "" {
- // quote newlines in suffix
- prefix = utils.EscapeNewLines(prefix)
-
- if maybeArray && maybeString {
- newRoot = "(" + newRoot + ")"
- }
-
- if maybeString {
- //newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) "
- newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) "
- } else {
- newRoot = "\"" + prefix + "\" " + "" + newRoot + ""
- }
- } else if maybeString {
- if maybeArray {
- // newRoot = "(" + newRoot + ")"
- }
-
- newRoot = freestringRule + " | " + newRoot
- }
-
- lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot))
- if disableParallelNewLines {
- lines = append(lines, array)
- } else {
- lines = append(lines, arrayNewLines)
- }
-
- if maybeArray {
- if grammarOpts.ExpectStringsAfterJSON {
- lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`)
- } else {
- lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`)
- }
- } else {
- if grammarOpts.ExpectStringsAfterJSON {
- lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`)
- } else {
- lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`)
- }
- }
-
- return strings.Join(lines, "\n")
-}
-
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
st, existType := schema["type"]
var schemaType string
@@ -182,7 +96,10 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
rule := strings.Join(alternatives, " | ")
return sc.addRule(ruleName, rule), nil
} else if ref, exists := schema["$ref"].(string); exists {
- referencedSchema := sc.resolveReference(ref, rootSchema)
+ referencedSchema, err := sc.resolveReference(ref, rootSchema)
+ if err != nil {
+ return "", err
+ }
return sc.visit(referencedSchema, name, rootSchema)
} else if constVal, exists := schema["const"]; exists {
literal, err := sc.formatLiteral((constVal))
@@ -257,7 +174,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
} else {
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
if !exists {
- panic(fmt.Sprintf("Unrecognized schema: %v", schema))
+ return "", fmt.Errorf("unrecognized schema: %v", schema)
}
if ruleName == "root" {
schemaType = "root"
@@ -265,27 +182,23 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
return sc.addRule(schemaType, primitiveRule), nil
}
}
-func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
+func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
if !strings.HasPrefix(ref, "#/$defs/") {
- panic(fmt.Sprintf("Invalid reference format: %s", ref))
+ return nil, fmt.Errorf("invalid reference format: %s", ref)
}
defKey := strings.TrimPrefix(ref, "#/$defs/")
definitions, exists := rootSchema["$defs"].(map[string]interface{})
if !exists {
- fmt.Println(rootSchema)
-
- panic("No definitions found in the schema")
+ return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
}
def, exists := definitions[defKey].(map[string]interface{})
if !exists {
- fmt.Println(definitions)
-
- panic(fmt.Sprintf("Definition not found: %s", defKey))
+ return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
}
- return def
+ return def, nil
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
@@ -294,7 +207,7 @@ func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ..
if err != nil {
return "", err
}
- return sc.finalizeGrammar(options...), nil
+ return sc.rules.ToGrammar(options...), nil
}
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
diff --git a/pkg/functions/grammar_json_schema_test.go b/pkg/functions/grammars/json_schema_test.go
similarity index 99%
rename from pkg/functions/grammar_json_schema_test.go
rename to pkg/functions/grammars/json_schema_test.go
index 56c5fe1e..5fc4a602 100644
--- a/pkg/functions/grammar_json_schema_test.go
+++ b/pkg/functions/grammars/json_schema_test.go
@@ -1,9 +1,10 @@
-package functions_test
+package grammars_test
import (
"strings"
. "github.com/mudler/LocalAI/pkg/functions"
+ . "github.com/mudler/LocalAI/pkg/functions/grammars"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
diff --git a/pkg/functions/grammars/llama31_schema.go b/pkg/functions/grammars/llama31_schema.go
new file mode 100644
index 00000000..04b74aa5
--- /dev/null
+++ b/pkg/functions/grammars/llama31_schema.go
@@ -0,0 +1,281 @@
+package grammars
+
+import (
+ "encoding/json"
+ "fmt"
+ "regexp"
+ "sort"
+ "strings"
+)
+
+type LLama31SchemaConverter struct {
+ fnName string
+ rules Rules
+}
+
+func NewLLama31SchemaConverter(fnName string) *LLama31SchemaConverter {
+ rules := make(map[string]string)
+ rules["space"] = SPACE_RULE
+ if fnName == "" {
+ fnName = "name"
+ }
+
+ return &LLama31SchemaConverter{
+ rules: rules,
+ fnName: fnName,
+ }
+}
+
+var GRAMMAR_LITERAL_ESCAPESLlama = map[string]string{
+ "\r": `\r`,
+ "\n": `\n`,
+}
+
+var GRAMMAR_LITERAL_ESCAPE_RELlama = regexp.MustCompile(`[\r\n]`)
+
+func (sc *LLama31SchemaConverter) formatLiteral(literal interface{}) (string, error) {
+ jLiteral, err := jsonString(literal)
+ if err != nil {
+ return "", err
+ }
+ escaped := GRAMMAR_LITERAL_ESCAPE_RELlama.ReplaceAllStringFunc(jLiteral, func(match string) string {
+ return GRAMMAR_LITERAL_ESCAPESLlama[match]
+ })
+ return escaped, nil
+}
+
+func (sc *LLama31SchemaConverter) formatLiteralQuoted(literal interface{}) (string, error) {
+ jLiteral, err := jsonString(literal)
+ if err != nil {
+ return "", err
+ }
+ escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
+ return GRAMMAR_LITERAL_ESCAPES[match]
+ })
+ return fmt.Sprintf(`"%s"`, escaped), nil
+}
+
+func (sc *LLama31SchemaConverter) addRule(name, rule string) string {
+ escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
+ key := escName
+ if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
+ i := 0
+ for {
+ key = fmt.Sprintf("%s%d", escName, i)
+ if _, ok := sc.rules[key]; !ok {
+ break
+ }
+ i++
+ }
+ }
+ sc.rules[key] = rule
+ return key
+}
+
+func (sc *LLama31SchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
+ st, existType := schema["type"]
+ var schemaType string
+ if existType {
+ schemaType = st.(string)
+ }
+ ruleName := name
+ if name == "" {
+ ruleName = "root"
+ }
+ _, oneOfExists := schema["oneOf"]
+ _, anyOfExists := schema["anyOf"]
+ if oneOfExists || anyOfExists {
+ var alternatives []string
+ oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
+ anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
+
+ if oneOfExists {
+ for i, altSchema := range oneOfSchemas {
+ alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
+ if err != nil {
+ return "", err
+ }
+ alternatives = append(alternatives, alternative)
+ }
+ } else if anyOfExists {
+ for i, altSchema := range anyOfSchemas {
+ alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
+ if err != nil {
+ return "", err
+ }
+ alternatives = append(alternatives, alternative)
+ }
+ }
+
+ rule := strings.Join(alternatives, " | ")
+ return sc.addRule(ruleName, rule), nil
+ } else if ref, exists := schema["$ref"].(string); exists {
+ referencedSchema, err := sc.resolveReference(ref, rootSchema)
+ if err != nil {
+ return "", err
+ }
+ return sc.visit(referencedSchema, name, rootSchema)
+ } else if constVal, exists := schema["const"]; exists {
+
+ literal, err := sc.formatLiteral((constVal))
+ if err != nil {
+ return "", err
+ }
+ return sc.addRule(ruleName, literal), nil
+ } else if enumVals, exists := schema["enum"].([]interface{}); exists {
+ var enumRules []string
+ for _, enumVal := range enumVals {
+ enumRule, err := sc.formatLiteralQuoted(enumVal)
+ if err != nil {
+ return "", err
+ }
+ enumRules = append(enumRules, enumRule)
+ }
+ rule := strings.Join(enumRules, " | ")
+ return sc.addRule(ruleName, rule), nil
+ } else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
+ baseProperty := false
+ depth := strings.Split(name, "-")
+ if len(depth) == 2 {
+ baseProperty = true
+ }
+ type propData []struct {
+ propName string
+ propSchema map[string]interface{}
+ }
+ var propPairs propData
+
+ for propName, propSchema := range properties {
+ propPairs = append(propPairs, struct {
+ propName string
+ propSchema map[string]interface{}
+ }{propName: propName, propSchema: propSchema.(map[string]interface{})})
+ }
+
+ sort.Slice(propPairs, func(i, j int) bool {
+ return propPairs[i].propName < propPairs[j].propName
+ })
+
+ var rule strings.Builder
+ if baseProperty {
+ rule.WriteString(`"{" `, propRuleName))
+
+ for _, propPair := range propPairs {
+ propName := propPair.propName
+ propSchema := propPair.propSchema
+ propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
+ if err != nil {
+ return "", err
+ }
+
+ rule.WriteString(propRuleName)
+ }
+
+ rule.WriteString(` "}"`)
+
+ } else {
+ for i, propPair := range propPairs {
+ propName := propPair.propName
+ propSchema := propPair.propSchema
+ propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
+ if err != nil {
+ return "", err
+ }
+ lPropName, err := sc.formatLiteralQuoted(propName)
+ if err != nil {
+ return "", err
+ }
+ if i > 0 {
+ rule.WriteString(` "," space`)
+ }
+
+ rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName))
+ }
+
+ }
+
+ if !baseProperty {
+ rule.WriteString(` "}" space`)
+ }
+
+ return sc.addRule(ruleName, rule.String()), nil
+ } else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
+ itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
+ if err != nil {
+ return "", err
+ }
+ rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
+ return sc.addRule(ruleName, rule), nil
+ } else {
+ primitiveRule, exists := PRIMITIVE_RULES[schemaType]
+ if !exists {
+ return "", fmt.Errorf("unrecognized schema: %v", schema)
+ }
+ if ruleName == "root" {
+ schemaType = "root"
+ }
+ return sc.addRule(schemaType, primitiveRule), nil
+ }
+}
+func (sc *LLama31SchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
+ if !strings.HasPrefix(ref, "#/$defs/") {
+ return nil, fmt.Errorf("invalid reference format: %s", ref)
+ }
+
+ defKey := strings.TrimPrefix(ref, "#/$defs/")
+ definitions, exists := rootSchema["$defs"].(map[string]interface{})
+ if !exists {
+ return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
+ }
+
+ def, exists := definitions[defKey].(map[string]interface{})
+ if !exists {
+ return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
+ }
+
+ return def, nil
+}
+
+func (sc *LLama31SchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
+ sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
+ _, err := sc.visit(schema, "", schema)
+ if err != nil {
+ return "", err
+ }
+ return sc.rules.ToGrammar(options...), nil
+}
+
+func (sc *LLama31SchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
+ var schema map[string]interface{}
+ err := json.Unmarshal(b, &schema)
+ if err != nil {
+ return "", err
+ }
+ return sc.Grammar(schema, options...)
+}
diff --git a/pkg/functions/grammars/llama31_schema_test.go b/pkg/functions/grammars/llama31_schema_test.go
new file mode 100644
index 00000000..84d09bd5
--- /dev/null
+++ b/pkg/functions/grammars/llama31_schema_test.go
@@ -0,0 +1,76 @@
+package grammars_test
+
+import (
+ "strings"
+
+ . "github.com/mudler/LocalAI/pkg/functions/grammars"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+const (
+ testllama31Input1 = `
+ {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "function": {"const": "create_event"},
+ "arguments": {
+ "type": "object",
+ "properties": {
+ "title": {"type": "string"},
+ "date": {"type": "string"},
+ "time": {"type": "string"}
+ }
+ }
+ }
+ },
+ {
+ "type": "object",
+ "properties": {
+ "function": {"const": "search"},
+ "arguments": {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"}
+ }
+ }
+ }
+ }
+ ]
+ }`
+ // {{"example_name": "example_value"}}
+ testllama31inputResult1 = `root-0-function ::= "create_event"
+freestring ::= (
+ [^"\\] |
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
+ )* space
+root-0 ::= "{" root-0-arguments "}"
+root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
+root ::= root-0 | root-1
+space ::= " "?
+root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
+root-1 ::= "{" root-1-arguments "}"
+string ::= "\"" (
+ [^"\\] |
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
+)* "\"" space
+root-1-function ::= "search"`
+)
+
+var _ = Describe("JSON schema grammar tests", func() {
+ Context("JSON", func() {
+ It("generates a valid grammar from JSON schema", func() {
+ grammar, err := NewLLama31SchemaConverter("function").GrammarFromBytes([]byte(testllama31Input1))
+ Expect(err).ToNot(HaveOccurred())
+ results := strings.Split(testllama31inputResult1, "\n")
+ for _, r := range results {
+ if r != "" {
+ Expect(grammar).To(ContainSubstring(r))
+ }
+ }
+ Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
+ })
+ })
+})
diff --git a/pkg/functions/options.go b/pkg/functions/grammars/options.go
similarity index 76%
rename from pkg/functions/options.go
rename to pkg/functions/grammars/options.go
index 3a341a43..07c6c951 100644
--- a/pkg/functions/options.go
+++ b/pkg/functions/grammars/options.go
@@ -1,4 +1,4 @@
-package functions
+package grammars
type GrammarOption struct {
PropOrder string
@@ -8,6 +8,9 @@ type GrammarOption struct {
MaybeString bool
NoMixedFreeString bool
ExpectStringsAfterJSON bool
+
+ FunctionName string
+ SchemaType SchemaConverterType
}
func (o *GrammarOption) Apply(options ...func(*GrammarOption)) {
@@ -48,3 +51,15 @@ func SetPropOrder(order string) func(*GrammarOption) {
o.PropOrder = order
}
}
+
+func WithSchemaType(schemaType SchemaConverterType) func(*GrammarOption) {
+ return func(o *GrammarOption) {
+ o.SchemaType = schemaType
+ }
+}
+
+func WithFunctionName(name string) func(*GrammarOption) {
+ return func(o *GrammarOption) {
+ o.FunctionName = name
+ }
+}
diff --git a/pkg/functions/grammars/rules.go b/pkg/functions/grammars/rules.go
new file mode 100644
index 00000000..84fc8a25
--- /dev/null
+++ b/pkg/functions/grammars/rules.go
@@ -0,0 +1,93 @@
+package grammars
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+type Rules map[string]string
+
+func (rules Rules) ToGrammar(options ...func(*GrammarOption)) string {
+ grammarOpts := &GrammarOption{}
+ grammarOpts.Apply(options...)
+
+ prefix := grammarOpts.Prefix
+ maybeArray := grammarOpts.MaybeArray
+ disableParallelNewLines := grammarOpts.DisableParallelNewLines
+ maybeString := grammarOpts.MaybeString
+ noMixedFreeString := grammarOpts.NoMixedFreeString
+
+ var lines []string
+
+ swapRoot := maybeArray || maybeString || prefix != ""
+
+ // write down the computed rules.
+ // if maybeArray is true, we need to add the array rule and slightly tweak the root rule
+ for name, rule := range rules {
+ if swapRoot && name == "root" {
+ name = "realvalue"
+ }
+ lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
+ }
+
+ if !swapRoot {
+ return strings.Join(lines, "\n")
+ }
+
+ newRoot := "realvalue"
+ if maybeArray {
+ newRoot = "arr | realvalue"
+ }
+
+ freestringRule := "mixedstring"
+ if noMixedFreeString {
+ freestringRule = "freestring"
+ }
+
+ if prefix != "" {
+ // quote newlines in suffix
+ prefix = utils.EscapeNewLines(prefix)
+
+ if maybeArray && maybeString {
+ newRoot = "(" + newRoot + ")"
+ }
+
+ if maybeString {
+ //newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) "
+ newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) "
+ } else {
+ newRoot = "\"" + prefix + "\" " + "" + newRoot + ""
+ }
+ } else if maybeString {
+ if maybeArray {
+ // newRoot = "(" + newRoot + ")"
+ }
+
+ newRoot = freestringRule + " | " + newRoot
+ }
+
+ lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot))
+ if disableParallelNewLines {
+ lines = append(lines, array)
+ } else {
+ lines = append(lines, arrayNewLines)
+ }
+
+ if maybeArray {
+ if grammarOpts.ExpectStringsAfterJSON {
+ lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`)
+ } else {
+ lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`)
+ }
+ } else {
+ if grammarOpts.ExpectStringsAfterJSON {
+ lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`)
+ } else {
+ lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`)
+ }
+ }
+
+ return strings.Join(lines, "\n")
+}
diff --git a/pkg/functions/grammars/types.go b/pkg/functions/grammars/types.go
new file mode 100644
index 00000000..1fe6444a
--- /dev/null
+++ b/pkg/functions/grammars/types.go
@@ -0,0 +1,33 @@
+package grammars
+
+type SchemaConverterType int
+
+const (
+ JSONSchema SchemaConverterType = iota
+ LLama31Schema
+)
+
+const (
+ LlamaType string = "llama3.1"
+ JSONType string = "json"
+)
+
+func (s SchemaConverterType) String() string {
+ switch s {
+ case JSONSchema:
+ return JSONType
+ case LLama31Schema:
+ return LlamaType
+ }
+ return "unknown"
+}
+
+func NewType(t string) SchemaConverterType {
+ switch t {
+ case JSONType:
+ return JSONSchema
+ case LlamaType:
+ return LLama31Schema
+ }
+ return JSONSchema
+}
diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go
index 8e848a60..f5593690 100644
--- a/pkg/functions/parse.go
+++ b/pkg/functions/parse.go
@@ -7,6 +7,7 @@ import (
"regexp"
"strings"
+ "github.com/mudler/LocalAI/pkg/functions/grammars"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
@@ -22,7 +23,9 @@ type GrammarConfig struct {
MixedMode bool `yaml:"mixed_mode"`
// NoMixedFreeString disables the mixed mode for free strings
- // In this way if the LLM selects a free string, it won't be mixed necessarly with JSON objects
+ // In this way if the LLM selects a free string, it won't be mixed necessarly with JSON objects.
+ // For example, if enabled the LLM or returns a JSON object or a free string, but not a mix of both
+ // If disabled(default): the LLM can return a JSON object surrounded by free strings (e.g. `this is the JSON result: { "bar": "baz" } for your question`). This forces the LLM to return at least a JSON object, but its not going to be strict
NoMixedFreeString bool `yaml:"no_mixed_free_string"`
// NoGrammar disables the grammar parsing and parses the responses directly from the LLM
@@ -39,6 +42,10 @@ type GrammarConfig struct {
// for instance name,arguments will make print { "name": "foo", "arguments": { "bar": "baz" } }
// instead of { "arguments": { "bar": "baz" }, "name": "foo" }
PropOrder string `yaml:"properties_order"`
+
+ // SchemaType can be configured to use a specific schema type to force the grammar
+ // available : json, llama3.1
+ SchemaType string `yaml:"schema_type"`
}
// FunctionsConfig is the configuration for the tool/function call.
@@ -92,28 +99,36 @@ type FuncCallResults struct {
Arguments string
}
-func (g GrammarConfig) Options() []func(o *GrammarOption) {
- opts := []func(o *GrammarOption){}
- if g.MixedMode {
- opts = append(opts, EnableMaybeString)
+func (g FunctionsConfig) GrammarOptions() []func(o *grammars.GrammarOption) {
+ opts := []func(o *grammars.GrammarOption){}
+ if g.GrammarConfig.MixedMode {
+ opts = append(opts, grammars.EnableMaybeString)
}
- if g.ParallelCalls {
- opts = append(opts, EnableMaybeArray)
+ if g.GrammarConfig.ParallelCalls {
+ opts = append(opts, grammars.EnableMaybeArray)
}
- if g.DisableParallelNewLines {
- opts = append(opts, DisableParallelNewLines)
+ if g.GrammarConfig.DisableParallelNewLines {
+ opts = append(opts, grammars.DisableParallelNewLines)
}
- if g.Prefix != "" {
- opts = append(opts, SetPrefix(g.Prefix))
+ if g.GrammarConfig.Prefix != "" {
+ opts = append(opts, grammars.SetPrefix(g.GrammarConfig.Prefix))
}
- if g.NoMixedFreeString {
- opts = append(opts, NoMixedFreeString)
+ if g.GrammarConfig.NoMixedFreeString {
+ opts = append(opts, grammars.NoMixedFreeString)
}
- if g.ExpectStringsAfterJSON {
- opts = append(opts, ExpectStringsAfterJSON)
+ if g.GrammarConfig.ExpectStringsAfterJSON {
+ opts = append(opts, grammars.ExpectStringsAfterJSON)
}
- opts = append(opts, SetPropOrder(g.PropOrder))
+ if g.GrammarConfig.SchemaType != "" {
+ opts = append(opts, grammars.WithSchemaType(grammars.NewType(g.GrammarConfig.SchemaType)))
+ }
+
+ if g.FunctionNameKey != "" {
+ opts = append(opts, grammars.WithFunctionName(g.FunctionNameKey))
+ }
+
+ opts = append(opts, grammars.SetPropOrder(g.GrammarConfig.PropOrder))
return opts
}