feat(openai): add json_schema format type and strict mode (#3193)

* feat(openai): add json_schema and strict mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* handle err vs _

security scanners prefer if we put these branches in, and I tend to agree.

Signed-off-by: Dave <dave@gray101.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Dave <dave@gray101.com>
Co-authored-by: Dave <dave@gray101.com>
This commit is contained in:
Ettore Di Giacinto 2024-08-07 21:27:02 +02:00 committed by GitHub
parent 66cf38b0b3
commit e198347886
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 3 deletions

View File

@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
funcs := input.Functions funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
strictMode := false
for _, f := range input.Functions {
if f.Strict {
strictMode = true
break
}
}
// Allow the user to set custom actions via config file // Allow the user to set custom actions via config file
// to be "embedded" in each model // to be "embedded" in each model
@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
if config.ResponseFormatMap != nil { if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{} d := schema.ChatCompletionResponseFormat{}
dat, _ := json.Marshal(config.ResponseFormatMap) dat, err := json.Marshal(config.ResponseFormatMap)
_ = json.Unmarshal(dat, &d) if err != nil {
return err
}
err = json.Unmarshal(dat, &d)
if err != nil {
return err
}
if d.Type == "json_object" { if d.Type == "json_object" {
input.Grammar = functions.JSONBNF input.Grammar = functions.JSONBNF
} else if d.Type == "json_schema" {
d := schema.JsonSchemaRequest{}
dat, err := json.Marshal(config.ResponseFormatMap)
if err != nil {
return err
}
err = json.Unmarshal(dat, &d)
if err != nil {
return err
}
fs := &functions.JSONFunctionStructure{
AnyOf: []functions.Item{d.JsonSchema.Schema},
}
g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
input.Grammar = g
}
} }
} }
@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
} }
switch { switch {
case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn: case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
noActionGrammar := functions.Function{ noActionGrammar := functions.Function{
Name: noActionName, Name: noActionName,
Description: noActionDescription, Description: noActionDescription,

View File

@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"` Type ChatCompletionResponseFormatType `json:"type,omitempty"`
} }
type JsonSchemaRequest struct {
Type string `json:"type"`
JsonSchema JsonSchema `json:"json_schema"`
}
type JsonSchema struct {
Name string `json:"name"`
Strict bool `json:"strict"`
Schema functions.Item `json:"schema"`
}
type OpenAIRequest struct { type OpenAIRequest struct {
PredictionOptions PredictionOptions

View File

@ -14,6 +14,7 @@ const (
type Function struct { type Function struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Strict bool `json:"strict"`
Parameters map[string]interface{} `json:"parameters"` Parameters map[string]interface{} `json:"parameters"`
} }
type Functions []Function type Functions []Function