// Source file: LocalAI/pkg/functions/grammars/json_schema.go

package grammars
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
import (
"encoding/json"
"fmt"
"sort"
"strings"
)
// JSONSchemaConverter holds the state used while turning a JSON schema into
// grammar rules: the requested property ordering and the rules produced so far.
type JSONSchemaConverter struct {
// propOrder maps a property name to its index in the comma-separated list
// passed to NewJSONSchemaConverter.
propOrder map[string]int
// rules stores the generated rule bodies keyed by rule name.
rules Rules
}
// NewJSONSchemaConverter returns a converter configured with the given
// property order.
//
// propOrder is a comma-separated list of property names; every entry is
// recorded together with its zero-based position in that list. The returned
// converter's rule set initially contains only the "space" rule.
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
	positions := map[string]int{}
	for pos, propName := range strings.Split(propOrder, ",") {
		positions[propName] = pos
	}

	return &JSONSchemaConverter{
		propOrder: positions,
		rules:     map[string]string{"space": SPACE_RULE},
	}
}
// formatLiteral renders literal as a double-quoted grammar literal.
//
// The value is first encoded with jsonString; every match of
// GRAMMAR_LITERAL_ESCAPE_RE in that encoding is then replaced by its entry in
// GRAMMAR_LITERAL_ESCAPES, and the result is wrapped in double quotes. An
// error from jsonString is returned unchanged.
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) (string, error) {
	encoded, err := jsonString(literal)
	if err != nil {
		return "", err
	}

	lookupEscape := func(matched string) string {
		return GRAMMAR_LITERAL_ESCAPES[matched]
	}
	body := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(encoded, lookupEscape)

	return fmt.Sprintf(`"%s"`, body), nil
}
// addRule registers rule under name and returns the key it was stored under.
//
// Characters matched by INVALID_RULE_CHARS_RE are replaced with "-" to form
// the base key. If the base key is free, or already holds an identical rule,
// it is used as-is. Otherwise a numeric suffix (0, 1, 2, ...) is appended and
// the first suffixed key that is either free or already holds an identical
// rule is used, so registering the same (name, rule) pair repeatedly does not
// create duplicate rules.
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
	escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
	key := escName
	if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
		for i := 0; ; i++ {
			key = fmt.Sprintf("%s%d", escName, i)
			// Stop on an unused key, or on one that already maps to this
			// exact rule (reuse instead of minting another copy).
			if suffixedRule, taken := sc.rules[key]; !taken || suffixedRule == rule {
				break
			}
		}
	}
	sc.rules[key] = rule
	return key
}
// visit translates schema into grammar rules, registers them on sc and
// returns the name of the rule that matches schema.
//
// name is the rule name to use for schema; the empty string denotes the root
// schema, which is registered as "root". rootSchema is the document against
// which "$ref" values are resolved.
//
// Schema forms are tried in this order: oneOf/anyOf, $ref, const, enum,
// an "object" with properties, an "array" with items, and finally a primitive
// type looked up in PRIMITIVE_RULES. Malformed nodes (a non-string "type", or
// a sub-schema that is not a JSON object) are reported as errors.
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
	var schemaType string
	if rawType, ok := schema["type"]; ok {
		// Only a single type name is supported; anything else (for example a
		// list of types) is rejected instead of panicking on the assertion.
		typeName, isString := rawType.(string)
		if !isString {
			return "", fmt.Errorf("unsupported schema type %v (%T): expected a string", rawType, rawType)
		}
		schemaType = typeName
	}

	ruleName := name
	if ruleName == "" {
		ruleName = "root"
	}

	_, hasOneOf := schema["oneOf"]
	_, hasAnyOf := schema["anyOf"]
	if hasOneOf || hasAnyOf {
		// oneOf wins when it is a list; anyOf is only consulted otherwise.
		altSchemas, ok := schema["oneOf"].([]interface{})
		if !ok {
			altSchemas, _ = schema["anyOf"].([]interface{})
		}
		alternatives := make([]string, 0, len(altSchemas))
		for i, rawAlt := range altSchemas {
			altSchema, isObject := rawAlt.(map[string]interface{})
			if !isObject {
				return "", fmt.Errorf("alternative %d of %s: schema must be an object, got %T", i, ruleName, rawAlt)
			}
			alternative, err := sc.visit(altSchema, fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
			if err != nil {
				return "", err
			}
			alternatives = append(alternatives, alternative)
		}
		return sc.addRule(ruleName, strings.Join(alternatives, " | ")), nil
	}

	if ref, ok := schema["$ref"].(string); ok {
		referencedSchema, err := sc.resolveReference(ref, rootSchema)
		if err != nil {
			return "", err
		}
		// The referenced schema is visited under the caller-supplied name.
		return sc.visit(referencedSchema, name, rootSchema)
	}

	if constVal, ok := schema["const"]; ok {
		literal, err := sc.formatLiteral(constVal)
		if err != nil {
			return "", err
		}
		return sc.addRule(ruleName, literal), nil
	}

	if enumVals, ok := schema["enum"].([]interface{}); ok {
		enumRules := make([]string, 0, len(enumVals))
		for _, enumVal := range enumVals {
			enumRule, err := sc.formatLiteral(enumVal)
			if err != nil {
				return "", err
			}
			enumRules = append(enumRules, enumRule)
		}
		return sc.addRule(ruleName, strings.Join(enumRules, " | ")), nil
	}

	if properties, ok := schema["properties"].(map[string]interface{}); ok && schemaType == "object" {
		type propPair struct {
			propName   string
			propSchema map[string]interface{}
		}
		propPairs := make([]propPair, 0, len(properties))
		for propName, rawSchema := range properties {
			propSchema, isObject := rawSchema.(map[string]interface{})
			if !isObject {
				return "", fmt.Errorf("property %q of %s: schema must be an object, got %T", propName, ruleName, rawSchema)
			}
			propPairs = append(propPairs, propPair{propName: propName, propSchema: propSchema})
		}

		// Properties named in the configured order come first, in that
		// order; all remaining properties follow alphabetically. Presence is
		// tested explicitly so that position 0 is honoured, and the
		// comparison is a consistent ordering, which keeps the output
		// independent of map iteration order.
		propOrder := sc.propOrder
		sort.Slice(propPairs, func(i, j int) bool {
			iOrder, iListed := propOrder[propPairs[i].propName]
			jOrder, jListed := propOrder[propPairs[j].propName]
			switch {
			case iListed != jListed:
				return iListed
			case iListed && iOrder != jOrder:
				return iOrder < jOrder
			}
			return propPairs[i].propName < propPairs[j].propName
		})

		var rule strings.Builder
		rule.WriteString(`"{" space`)
		for i, propPair := range propPairs {
			propRuleName, err := sc.visit(propPair.propSchema, fmt.Sprintf("%s-%s", ruleName, propPair.propName), rootSchema)
			if err != nil {
				return "", err
			}
			lPropName, err := sc.formatLiteral(propPair.propName)
			if err != nil {
				return "", err
			}
			if i > 0 {
				rule.WriteString(` "," space`)
			}
			fmt.Fprintf(&rule, ` %s space ":" space %s`, lPropName, propRuleName)
		}
		rule.WriteString(` "}" space`)
		return sc.addRule(ruleName, rule.String()), nil
	}

	if items, ok := schema["items"].(map[string]interface{}); ok && schemaType == "array" {
		itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
		if err != nil {
			return "", err
		}
		rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
		return sc.addRule(ruleName, rule), nil
	}

	primitiveRule, ok := PRIMITIVE_RULES[schemaType]
	if !ok {
		return "", fmt.Errorf("unrecognized schema: %v", schema)
	}
	// Primitive rules are registered under their type name, except at the
	// top level where the rule must be called "root".
	if ruleName == "root" {
		schemaType = "root"
	}
	return sc.addRule(schemaType, primitiveRule), nil
}
// resolveReference looks up the schema that ref points to inside rootSchema.
//
// Only references of the form "#/$defs/<key>" are accepted. An error is
// returned when ref has a different prefix, when rootSchema has no "$defs"
// object, or when "$defs" has no object stored under <key>.
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
	const defsPrefix = "#/$defs/"
	if !strings.HasPrefix(ref, defsPrefix) {
		return nil, fmt.Errorf("invalid reference format: %s", ref)
	}
	defKey := strings.TrimPrefix(ref, defsPrefix)

	definitions, ok := rootSchema["$defs"].(map[string]interface{})
	if !ok {
		// %v rather than %s: the schema holds arbitrary JSON values, and %s
		// renders non-string values as "%!s(...)" noise.
		return nil, fmt.Errorf("no definitions found in the schema: %v", rootSchema)
	}

	def, ok := definitions[defKey].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
	}
	return def, nil
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
feat(functions): mixed JSON BNF grammars (#2328) feat(functions): support mixed JSON BNF grammar This PR provides new options to control how functions are extracted from the LLM, and also provides more control on how JSON grammars can be used (also in conjunction). New YAML settings introduced: - `grammar_message`: when enabled, the generated grammar can also decide to push strings and not only JSON objects. This allows the LLM to pick to either respond freely or using JSON. - `grammar_prefix`: Allows to prefix a string to the JSON grammar definition. - `replace_results`: Is a map that allows to replace strings in the LLM result. As an example, consider the following settings for Hermes-2-Pro-Mistral, which allow extracting both JSON results coming from the model, and the ones coming from the grammar: ```yaml function: # disable injecting the "answer" tool disable_no_action: true # This allows the grammar to also return messages grammar_message: true # Suffix to add to the grammar grammar_prefix: '<tool_call>\n' return_name_in_function_response: true # Without grammar uncomment the lines below # Warning: this is relying only on the capability of the # LLM model to generate the correct function call. # no_grammar: true # json_regex_match: "(?s)<tool_call>(.*?)</tool_call>" replace_results: "<tool_call>": "" "\'": "\"" ``` Note: To disable entirely grammars usage in the example above, uncomment the `no_grammar` and `json_regex_match`. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-05-15 18:03:18 +00:00
sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
_, err := sc.visit(schema, "", schema)
if err != nil {
return "", err
}
return sc.rules.ToGrammar(options...), nil
}
// GrammarFromBytes decodes b as a JSON object and converts it with Grammar,
// forwarding options unchanged. A JSON decoding error is returned as-is.
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal(b, &decoded); err != nil {
		return "", err
	}
	return sc.Grammar(decoded, options...)
}