mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-21 21:47:51 +00:00
282 lines
7.7 KiB
Go
282 lines
7.7 KiB
Go
|
package grammars
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"regexp"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
type LLama31SchemaConverter struct {
|
||
|
fnName string
|
||
|
rules Rules
|
||
|
}
|
||
|
|
||
|
func NewLLama31SchemaConverter(fnName string) *LLama31SchemaConverter {
|
||
|
rules := make(map[string]string)
|
||
|
rules["space"] = SPACE_RULE
|
||
|
if fnName == "" {
|
||
|
fnName = "name"
|
||
|
}
|
||
|
|
||
|
return &LLama31SchemaConverter{
|
||
|
rules: rules,
|
||
|
fnName: fnName,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
	// GRAMMAR_LITERAL_ESCAPESLlama maps a raw carriage return or line feed to
	// its two-character backslash escape. No other characters are mapped.
	GRAMMAR_LITERAL_ESCAPESLlama = map[string]string{
		"\n": `\n`,
		"\r": `\r`,
	}

	// GRAMMAR_LITERAL_ESCAPE_RELlama matches exactly the single characters
	// that GRAMMAR_LITERAL_ESCAPESLlama knows how to escape.
	GRAMMAR_LITERAL_ESCAPE_RELlama = regexp.MustCompile(`[\r\n]`)
)
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) formatLiteral(literal interface{}) (string, error) {
|
||
|
jLiteral, err := jsonString(literal)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
escaped := GRAMMAR_LITERAL_ESCAPE_RELlama.ReplaceAllStringFunc(jLiteral, func(match string) string {
|
||
|
return GRAMMAR_LITERAL_ESCAPESLlama[match]
|
||
|
})
|
||
|
return escaped, nil
|
||
|
}
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) formatLiteralQuoted(literal interface{}) (string, error) {
|
||
|
jLiteral, err := jsonString(literal)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
|
||
|
return GRAMMAR_LITERAL_ESCAPES[match]
|
||
|
})
|
||
|
return fmt.Sprintf(`"%s"`, escaped), nil
|
||
|
}
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) addRule(name, rule string) string {
|
||
|
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
|
||
|
key := escName
|
||
|
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
|
||
|
i := 0
|
||
|
for {
|
||
|
key = fmt.Sprintf("%s%d", escName, i)
|
||
|
if _, ok := sc.rules[key]; !ok {
|
||
|
break
|
||
|
}
|
||
|
i++
|
||
|
}
|
||
|
}
|
||
|
sc.rules[key] = rule
|
||
|
return key
|
||
|
}
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
|
||
|
st, existType := schema["type"]
|
||
|
var schemaType string
|
||
|
if existType {
|
||
|
schemaType = st.(string)
|
||
|
}
|
||
|
ruleName := name
|
||
|
if name == "" {
|
||
|
ruleName = "root"
|
||
|
}
|
||
|
_, oneOfExists := schema["oneOf"]
|
||
|
_, anyOfExists := schema["anyOf"]
|
||
|
if oneOfExists || anyOfExists {
|
||
|
var alternatives []string
|
||
|
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
|
||
|
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
|
||
|
|
||
|
if oneOfExists {
|
||
|
for i, altSchema := range oneOfSchemas {
|
||
|
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
alternatives = append(alternatives, alternative)
|
||
|
}
|
||
|
} else if anyOfExists {
|
||
|
for i, altSchema := range anyOfSchemas {
|
||
|
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
alternatives = append(alternatives, alternative)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
rule := strings.Join(alternatives, " | ")
|
||
|
return sc.addRule(ruleName, rule), nil
|
||
|
} else if ref, exists := schema["$ref"].(string); exists {
|
||
|
referencedSchema, err := sc.resolveReference(ref, rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return sc.visit(referencedSchema, name, rootSchema)
|
||
|
} else if constVal, exists := schema["const"]; exists {
|
||
|
|
||
|
literal, err := sc.formatLiteral((constVal))
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return sc.addRule(ruleName, literal), nil
|
||
|
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||
|
var enumRules []string
|
||
|
for _, enumVal := range enumVals {
|
||
|
enumRule, err := sc.formatLiteralQuoted(enumVal)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
enumRules = append(enumRules, enumRule)
|
||
|
}
|
||
|
rule := strings.Join(enumRules, " | ")
|
||
|
return sc.addRule(ruleName, rule), nil
|
||
|
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
|
||
|
baseProperty := false
|
||
|
depth := strings.Split(name, "-")
|
||
|
if len(depth) == 2 {
|
||
|
baseProperty = true
|
||
|
}
|
||
|
type propData []struct {
|
||
|
propName string
|
||
|
propSchema map[string]interface{}
|
||
|
}
|
||
|
var propPairs propData
|
||
|
|
||
|
for propName, propSchema := range properties {
|
||
|
propPairs = append(propPairs, struct {
|
||
|
propName string
|
||
|
propSchema map[string]interface{}
|
||
|
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
|
||
|
}
|
||
|
|
||
|
sort.Slice(propPairs, func(i, j int) bool {
|
||
|
return propPairs[i].propName < propPairs[j].propName
|
||
|
})
|
||
|
|
||
|
var rule strings.Builder
|
||
|
if baseProperty {
|
||
|
rule.WriteString(`"<function="`)
|
||
|
} else {
|
||
|
rule.WriteString(`"{" space`)
|
||
|
}
|
||
|
|
||
|
if baseProperty {
|
||
|
|
||
|
namePair := propData{}
|
||
|
for i, propPair := range propPairs {
|
||
|
propName := propPair.propName
|
||
|
if propName == sc.fnName {
|
||
|
namePair = append(namePair, propPair)
|
||
|
// remove namePair from propPairs
|
||
|
propPairs = append(propPairs[:i], propPairs[i+1:]...)
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if len(namePair) == 0 {
|
||
|
return "", fmt.Errorf("no function name found in the schema: %s", schema)
|
||
|
}
|
||
|
|
||
|
propRuleName, err := sc.visit(namePair[0].propSchema, fmt.Sprintf("%s-%s", ruleName, sc.fnName), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
rule.WriteString(fmt.Sprintf(` %s ">{" `, propRuleName))
|
||
|
|
||
|
for _, propPair := range propPairs {
|
||
|
propName := propPair.propName
|
||
|
propSchema := propPair.propSchema
|
||
|
propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
rule.WriteString(propRuleName)
|
||
|
}
|
||
|
|
||
|
rule.WriteString(` "}</function>"`)
|
||
|
|
||
|
} else {
|
||
|
for i, propPair := range propPairs {
|
||
|
propName := propPair.propName
|
||
|
propSchema := propPair.propSchema
|
||
|
propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
lPropName, err := sc.formatLiteralQuoted(propName)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
if i > 0 {
|
||
|
rule.WriteString(` "," space`)
|
||
|
}
|
||
|
|
||
|
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName))
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
if !baseProperty {
|
||
|
rule.WriteString(` "}" space`)
|
||
|
}
|
||
|
|
||
|
return sc.addRule(ruleName, rule.String()), nil
|
||
|
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||
|
itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||
|
return sc.addRule(ruleName, rule), nil
|
||
|
} else {
|
||
|
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
|
||
|
if !exists {
|
||
|
return "", fmt.Errorf("unrecognized schema: %v", schema)
|
||
|
}
|
||
|
if ruleName == "root" {
|
||
|
schemaType = "root"
|
||
|
}
|
||
|
return sc.addRule(schemaType, primitiveRule), nil
|
||
|
}
|
||
|
}
|
||
|
func (sc *LLama31SchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
|
||
|
if !strings.HasPrefix(ref, "#/$defs/") {
|
||
|
return nil, fmt.Errorf("invalid reference format: %s", ref)
|
||
|
}
|
||
|
|
||
|
defKey := strings.TrimPrefix(ref, "#/$defs/")
|
||
|
definitions, exists := rootSchema["$defs"].(map[string]interface{})
|
||
|
if !exists {
|
||
|
return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
|
||
|
}
|
||
|
|
||
|
def, exists := definitions[defKey].(map[string]interface{})
|
||
|
if !exists {
|
||
|
return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
|
||
|
}
|
||
|
|
||
|
return def, nil
|
||
|
}
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
|
||
|
sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
|
||
|
_, err := sc.visit(schema, "", schema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return sc.rules.ToGrammar(options...), nil
|
||
|
}
|
||
|
|
||
|
func (sc *LLama31SchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
|
||
|
var schema map[string]interface{}
|
||
|
err := json.Unmarshal(b, &schema)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return sc.Grammar(schema, options...)
|
||
|
}
|