package grammars

import (
	"encoding/json"
	"fmt"
	"regexp"
	"sort"
	"strings"
)

// LLama31SchemaConverter converts a JSON schema into grammar rules. The
// direct alternatives of a top-level oneOf/anyOf are rendered as Llama 3.1
// custom tool calls (`<function=NAME>ARGUMENTS</function>`); every other
// object, array, enum or primitive is rendered as JSON.
type LLama31SchemaConverter struct {
	// fnName is the property of each top-level alternative whose rule
	// matches the function name.
	fnName string
	// rules accumulates every generated rule, keyed by rule name.
	rules Rules
}

// NewLLama31SchemaConverter returns a converter that reads the function name
// from the property fnName of each top-level alternative. An empty fnName
// defaults to "name".
func NewLLama31SchemaConverter(fnName string) *LLama31SchemaConverter {
	rules := make(map[string]string)
	rules["space"] = SPACE_RULE
	if fnName == "" {
		fnName = "name"
	}
	return &LLama31SchemaConverter{
		rules:  rules,
		fnName: fnName,
	}
}

// GRAMMAR_LITERAL_ESCAPESLlama maps each character matched by
// GRAMMAR_LITERAL_ESCAPE_RELlama to its escaped form inside a grammar literal.
var GRAMMAR_LITERAL_ESCAPESLlama = map[string]string{
	"\r": `\r`,
	"\n": `\n`,
}

// GRAMMAR_LITERAL_ESCAPE_RELlama matches the characters that formatLiteral
// escapes (carriage return and line feed).
var GRAMMAR_LITERAL_ESCAPE_RELlama = regexp.MustCompile(`[\r\n]`)

// llama31FunctionRuleName matches the names given to the direct alternatives
// of a top-level oneOf/anyOf ("root-0", "root-1", ...). Object schemas visited
// under such a name are rendered as function calls instead of JSON objects.
var llama31FunctionRuleName = regexp.MustCompile(`^root-[0-9]+$`)

// formatLiteral returns the jsonString encoding of literal with CR/LF
// replaced by `\r`/`\n`. It adds no extra quoting, so a JSON string such as
// "foo" becomes a grammar literal matching the bare text foo.
func (sc *LLama31SchemaConverter) formatLiteral(literal interface{}) (string, error) {
	jLiteral, err := jsonString(literal)
	if err != nil {
		return "", err
	}
	escaped := GRAMMAR_LITERAL_ESCAPE_RELlama.ReplaceAllStringFunc(jLiteral, func(match string) string {
		return GRAMMAR_LITERAL_ESCAPESLlama[match]
	})
	return escaped, nil
}

// formatLiteralQuoted returns the jsonString encoding of literal, escaped with
// GRAMMAR_LITERAL_ESCAPE_RE/GRAMMAR_LITERAL_ESCAPES and wrapped in double
// quotes. NOTE(review): those escapes are defined elsewhere; presumably they
// escape `"` so the result is a grammar literal matching the JSON text itself.
func (sc *LLama31SchemaConverter) formatLiteralQuoted(literal interface{}) (string, error) {
	jLiteral, err := jsonString(literal)
	if err != nil {
		return "", err
	}
	escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
		return GRAMMAR_LITERAL_ESCAPES[match]
	})
	return fmt.Sprintf(`"%s"`, escaped), nil
}

// addRule registers rule under name, with characters matched by
// INVALID_RULE_CHARS_RE replaced by "-". If that name already holds a
// different rule, the first free numeric suffix is appended. It returns the
// key the rule was stored under.
func (sc *LLama31SchemaConverter) addRule(name, rule string) string {
	escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
	key := escName
	if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
		for i := 0; ; i++ {
			key = fmt.Sprintf("%s%d", escName, i)
			if _, taken := sc.rules[key]; !taken {
				break
			}
		}
	}
	sc.rules[key] = rule
	return key
}

// visit converts schema into grammar rules, registers them, and returns the
// name of the rule matching schema. name is the rule-name prefix for schema
// ("" for the root, which is then named "root"); rootSchema is used to
// resolve "$ref" pointers.
//
// NOTE(review): a self-referencing "$ref" recurses without bound.
func (sc *LLama31SchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
	// A missing or non-string "type" leaves schemaType empty, which ends in
	// the "unrecognized schema" error instead of a panic.
	schemaType, _ := schema["type"].(string)

	ruleName := name
	if ruleName == "" {
		ruleName = "root"
	}

	if _, ok := schema["oneOf"]; ok {
		return sc.visitAlternatives(schema, ruleName, rootSchema)
	}
	if _, ok := schema["anyOf"]; ok {
		return sc.visitAlternatives(schema, ruleName, rootSchema)
	}
	if ref, ok := schema["$ref"].(string); ok {
		referencedSchema, err := sc.resolveReference(ref, rootSchema)
		if err != nil {
			return "", err
		}
		return sc.visit(referencedSchema, name, rootSchema)
	}
	if constVal, ok := schema["const"]; ok {
		// NOTE(review): formatLiteral does not JSON-quote the value. That
		// suits the bare function name in <function=NAME>, but a const inside
		// the JSON arguments would presumably need formatLiteralQuoted.
		literal, err := sc.formatLiteral(constVal)
		if err != nil {
			return "", err
		}
		return sc.addRule(ruleName, literal), nil
	}
	if enumVals, ok := schema["enum"].([]interface{}); ok {
		return sc.visitEnum(enumVals, ruleName)
	}
	if properties, ok := schema["properties"].(map[string]interface{}); ok && schemaType == "object" {
		return sc.visitObject(properties, name, ruleName, rootSchema)
	}
	if items, ok := schema["items"].(map[string]interface{}); ok && schemaType == "array" {
		return sc.visitArray(items, ruleName, rootSchema)
	}
	return sc.visitPrimitive(schema, schemaType, ruleName)
}

// visitAlternatives emits a rule matching any one of the schemas listed under
// "oneOf" (preferred when non-empty) or "anyOf". Alternative i is visited
// under the name "<ruleName>-<i>".
func (sc *LLama31SchemaConverter) visitAlternatives(schema map[string]interface{}, ruleName string, rootSchema map[string]interface{}) (string, error) {
	alts, _ := schema["oneOf"].([]interface{})
	if len(alts) == 0 {
		alts, _ = schema["anyOf"].([]interface{})
	}
	if len(alts) == 0 {
		return "", fmt.Errorf("%s: oneOf/anyOf must be a non-empty array", ruleName)
	}

	alternatives := make([]string, 0, len(alts))
	for i, alt := range alts {
		altSchema, ok := alt.(map[string]interface{})
		if !ok {
			return "", fmt.Errorf("%s: alternative %d is not a schema object", ruleName, i)
		}
		altRule, err := sc.visit(altSchema, fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
		if err != nil {
			return "", err
		}
		alternatives = append(alternatives, altRule)
	}
	return sc.addRule(ruleName, strings.Join(alternatives, " | ")), nil
}

// visitEnum emits a rule matching any of values, each as a quoted JSON literal.
func (sc *LLama31SchemaConverter) visitEnum(values []interface{}, ruleName string) (string, error) {
	if len(values) == 0 {
		return "", fmt.Errorf("%s: enum must not be empty", ruleName)
	}
	literals := make([]string, 0, len(values))
	for _, v := range values {
		literal, err := sc.formatLiteralQuoted(v)
		if err != nil {
			return "", err
		}
		literals = append(literals, literal)
	}
	return sc.addRule(ruleName, strings.Join(literals, " | ")), nil
}

// visitObject visits every property (in sorted key order, each under
// "<ruleName>-<key>") and then emits either a function-call rule, for direct
// top-level alternatives, or a JSON object rule.
func (sc *LLama31SchemaConverter) visitObject(properties map[string]interface{}, name, ruleName string, rootSchema map[string]interface{}) (string, error) {
	keys := make([]string, 0, len(properties))
	for k := range properties {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	propRules := make(map[string]string, len(keys))
	for _, k := range keys {
		propSchema, ok := properties[k].(map[string]interface{})
		if !ok {
			return "", fmt.Errorf("%s: property %q is not a schema object", ruleName, k)
		}
		propRule, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, k), rootSchema)
		if err != nil {
			return "", err
		}
		propRules[k] = propRule
	}

	if llama31FunctionRuleName.MatchString(name) {
		return sc.functionCallRule(keys, propRules, ruleName)
	}
	return sc.jsonObjectRule(keys, propRules, ruleName)
}

// functionCallRule emits `"<function=" NAME ">" ARGS... "</function>"`, where
// NAME is the rule of the sc.fnName property and ARGS are the rules of the
// remaining properties in sorted key order.
//
// NOTE(review): the layout follows Meta's documented Llama 3.1 custom
// tool-call syntax (<function=name>{...}</function>); the arguments rule
// supplies the braces. Confirm against the package's expected grammars.
func (sc *LLama31SchemaConverter) functionCallRule(keys []string, propRules map[string]string, ruleName string) (string, error) {
	nameRule, ok := propRules[sc.fnName]
	if !ok {
		return "", fmt.Errorf("%s: function name property %q not found", ruleName, sc.fnName)
	}

	var rule strings.Builder
	rule.WriteString(`"<function=" `)
	rule.WriteString(nameRule)
	rule.WriteString(` ">"`)
	for _, k := range keys {
		if k == sc.fnName {
			continue
		}
		// Separate rule references so they are not fused into one identifier.
		rule.WriteString(" ")
		rule.WriteString(propRules[k])
	}
	rule.WriteString(` "</function>"`)
	return sc.addRule(ruleName, rule.String()), nil
}

// jsonObjectRule emits a JSON object rule with the properties in keys order:
// `"{" space "k1" space ":" space R1 "," space ... "}" space`.
func (sc *LLama31SchemaConverter) jsonObjectRule(keys []string, propRules map[string]string, ruleName string) (string, error) {
	var rule strings.Builder
	rule.WriteString(`"{" space`)
	for i, k := range keys {
		lPropName, err := sc.formatLiteralQuoted(k)
		if err != nil {
			return "", err
		}
		if i > 0 {
			rule.WriteString(` "," space`)
		}
		fmt.Fprintf(&rule, ` %s space ":" space %s`, lPropName, propRules[k])
	}
	rule.WriteString(` "}" space`)
	return sc.addRule(ruleName, rule.String()), nil
}

// visitArray emits a rule for a JSON array whose elements match items, which
// is visited under the name "<ruleName>-item".
func (sc *LLama31SchemaConverter) visitArray(items map[string]interface{}, ruleName string, rootSchema map[string]interface{}) (string, error) {
	itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
	if err != nil {
		return "", err
	}
	rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
	return sc.addRule(ruleName, rule), nil
}

// visitPrimitive registers the PRIMITIVE_RULES entry for schemaType, under
// the type name, or under "root" when the schema is the root.
func (sc *LLama31SchemaConverter) visitPrimitive(schema map[string]interface{}, schemaType, ruleName string) (string, error) {
	primitiveRule, ok := PRIMITIVE_RULES[schemaType]
	if !ok {
		return "", fmt.Errorf("unrecognized schema: %v", schema)
	}
	if ruleName == "root" {
		schemaType = "root"
	}
	return sc.addRule(schemaType, primitiveRule), nil
}

// resolveReference returns the schema that ref points to. Only local
// references of the form "#/$defs/<key>" into rootSchema are supported.
func (sc *LLama31SchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
	if !strings.HasPrefix(ref, "#/$defs/") {
		return nil, fmt.Errorf("invalid reference format: %s", ref)
	}
	defKey := strings.TrimPrefix(ref, "#/$defs/")
	definitions, ok := rootSchema["$defs"].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("no definitions found in the schema: %v", rootSchema)
	}
	def, ok := definitions[defKey].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
	}
	return def, nil
}

// Grammar converts schema into a grammar: it registers the "freestring"
// primitive, visits schema as the root, and renders every rule accumulated
// on sc with ToGrammar. Rules are kept on sc, so repeated calls accumulate.
func (sc *LLama31SchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
	sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
	if _, err := sc.visit(schema, "", schema); err != nil {
		return "", err
	}
	return sc.rules.ToGrammar(options...), nil
}

// GrammarFromBytes decodes b as a JSON object and passes it to Grammar.
func (sc *LLama31SchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
	var schema map[string]interface{}
	if err := json.Unmarshal(b, &schema); err != nil {
		return "", err
	}
	return sc.Grammar(schema, options...)
}