Mirror of https://github.com/mudler/LocalAI.git
feat(openai): add json_schema format type and strict mode (#3193)
* feat(openai): add json_schema and strict mode

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* handle err vs _

  security scanners prefer if we put these branches in, and I tend to agree.

  Signed-off-by: Dave <dave@gray101.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Dave <dave@gray101.com>
Co-authored-by: Dave <dave@gray101.com>
parent 66cf38b0b3
commit e198347886
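For context, a chat-completions request exercising the new format type could look like the sketch below. This is only an illustration of the shape the endpoint expects: the model name and the schema contents are made-up placeholders, while the field names ("type", "json_schema", "name", "strict", "schema") mirror the JsonSchemaRequest and JsonSchema structs introduced further down in this diff.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical request body; "some-local-model" and the schema are placeholders.
	body := map[string]interface{}{
		"model": "some-local-model",
		"messages": []map[string]string{
			{"role": "user", "content": "List three primary colors as JSON."},
		},
		// The new response_format type handled by this commit.
		"response_format": map[string]interface{}{
			"type": "json_schema",
			"json_schema": map[string]interface{}{
				"name":   "colors",
				"strict": true,
				"schema": map[string]interface{}{
					"type": "object",
					"properties": map[string]interface{}{
						"colors": map[string]interface{}{
							"type":  "array",
							"items": map[string]string{"type": "string"},
						},
					},
					"required": []string{"colors"},
				},
			},
		},
	}
	out, _ := json.MarshalIndent(body, "", "  ")
	fmt.Println(string(out))
}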
@@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	funcs := input.Functions
 	shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
+	strictMode := false
+
+	for _, f := range input.Functions {
+		if f.Strict {
+			strictMode = true
+			break
+		}
+	}
 
 	// Allow the user to set custom actions via config file
 	// to be "embedded" in each model
@@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 
 	if config.ResponseFormatMap != nil {
 		d := schema.ChatCompletionResponseFormat{}
-		dat, _ := json.Marshal(config.ResponseFormatMap)
-		_ = json.Unmarshal(dat, &d)
+		dat, err := json.Marshal(config.ResponseFormatMap)
+		if err != nil {
+			return err
+		}
+		err = json.Unmarshal(dat, &d)
+		if err != nil {
+			return err
+		}
 		if d.Type == "json_object" {
 			input.Grammar = functions.JSONBNF
+		} else if d.Type == "json_schema" {
+			d := schema.JsonSchemaRequest{}
+			dat, err := json.Marshal(config.ResponseFormatMap)
+			if err != nil {
+				return err
+			}
+			err = json.Unmarshal(dat, &d)
+			if err != nil {
+				return err
+			}
+			fs := &functions.JSONFunctionStructure{
+				AnyOf: []functions.Item{d.JsonSchema.Schema},
+			}
+			g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
+			if err == nil {
+				input.Grammar = g
+			}
 		}
 	}
 
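The hunk above decodes the loosely typed ResponseFormatMap twice: once into ChatCompletionResponseFormat to read the type, and again into JsonSchemaRequest when the type is "json_schema", before feeding the embedded schema to the grammar generator. A minimal, self-contained sketch of that map-to-struct round trip follows; the local types here are simplified stand-ins for schema.JsonSchemaRequest and functions.Item, not the real definitions.

package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for schema.JsonSchemaRequest / schema.JsonSchema.
type jsonSchema struct {
	Name   string                 `json:"name"`
	Strict bool                   `json:"strict"`
	Schema map[string]interface{} `json:"schema"` // stand-in for functions.Item
}

type jsonSchemaRequest struct {
	Type       string     `json:"type"`
	JsonSchema jsonSchema `json:"json_schema"`
}

// decodeResponseFormat re-marshals the generic map and unmarshals it into the
// typed struct, mirroring the error-checked pattern used in the endpoint.
func decodeResponseFormat(m map[string]interface{}) (*jsonSchemaRequest, error) {
	dat, err := json.Marshal(m)
	if err != nil {
		return nil, err
	}
	d := jsonSchemaRequest{}
	if err := json.Unmarshal(dat, &d); err != nil {
		return nil, err
	}
	return &d, nil
}

func main() {
	m := map[string]interface{}{
		"type": "json_schema",
		"json_schema": map[string]interface{}{
			"name":   "colors",
			"strict": true,
			"schema": map[string]interface{}{"type": "object"},
		},
	}
	d, err := decodeResponseFormat(m)
	if err != nil {
		panic(err)
	}
	fmt.Printf("name=%s strict=%v\n", d.JsonSchema.Name, d.JsonSchema.Strict)
}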
@@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	}
 
 	switch {
-	case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn:
+	case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
 		noActionGrammar := functions.Function{
 			Name:        noActionName,
 			Description: noActionDescription,
@@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct {
 	Type ChatCompletionResponseFormatType `json:"type,omitempty"`
 }
 
+type JsonSchemaRequest struct {
+	Type       string     `json:"type"`
+	JsonSchema JsonSchema `json:"json_schema"`
+}
+
+type JsonSchema struct {
+	Name   string         `json:"name"`
+	Strict bool           `json:"strict"`
+	Schema functions.Item `json:"schema"`
+}
+
 type OpenAIRequest struct {
 	PredictionOptions
 
@@ -14,6 +14,7 @@ const (
 type Function struct {
 	Name        string                 `json:"name"`
 	Description string                 `json:"description"`
+	Strict      bool                   `json:"strict"`
 	Parameters  map[string]interface{} `json:"parameters"`
 }
 type Functions []Function
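With the new Strict field, a function declared by the client can opt into strict mode; the ChatEndpoint change above then enters the grammar-constrained branch even when NoGrammar is configured. Below is a hedged example of such a declaration, using a local stand-in struct rather than the real functions.Function; the function name and parameters are invented for illustration.

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for functions.Function with the new Strict field.
type function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Strict      bool                   `json:"strict"`
	Parameters  map[string]interface{} `json:"parameters"`
}

func main() {
	// Hypothetical function declaration sent by a client; setting "strict"
	// on any function flips strictMode in the endpoint shown earlier.
	raw := []byte(`{
		"name": "get_weather",
		"description": "Return the current weather for a city",
		"strict": true,
		"parameters": {
			"type": "object",
			"properties": {"city": {"type": "string"}},
			"required": ["city"]
		}
	}`)

	var f function
	if err := json.Unmarshal(raw, &f); err != nil {
		panic(err)
	}
	fmt.Println(f.Name, "strict:", f.Strict)
}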