From e198347886199a8119140f0d7d1a6442b4541ebc Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 7 Aug 2024 21:27:02 +0200 Subject: [PATCH] feat(openai): add `json_schema` format type and strict mode (#3193) * feat(openai): add json_schema and strict mode Signed-off-by: Ettore Di Giacinto * handle err vs _ security scanners prefer if we put these branches in, and I tend to agree. Signed-off-by: Dave --------- Signed-off-by: Ettore Di Giacinto Signed-off-by: Dave Co-authored-by: Dave --- core/http/endpoints/openai/chat.go | 37 +++++++++++++++++++++++++++--- core/schema/openai.go | 11 +++++++++ pkg/functions/functions.go | 1 + 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 86b75601..12a14eac 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup funcs := input.Functions shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() + strictMode := false + + for _, f := range input.Functions { + if f.Strict { + strictMode = true + break + } + } // Allow the user to set custom actions via config file // to be "embedded" in each model @@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup if config.ResponseFormatMap != nil { d := schema.ChatCompletionResponseFormat{} - dat, _ := json.Marshal(config.ResponseFormatMap) - _ = json.Unmarshal(dat, &d) + dat, err := json.Marshal(config.ResponseFormatMap) + if err != nil { + return err + } + err = json.Unmarshal(dat, &d) + if err != nil { + return err + } if d.Type == "json_object" { input.Grammar = functions.JSONBNF + } else if d.Type == "json_schema" { + d := schema.JsonSchemaRequest{} + dat, err := json.Marshal(config.ResponseFormatMap) + if err != nil { + return err + } + err = json.Unmarshal(dat, &d) + if err != nil { + return err + } + fs := &functions.JSONFunctionStructure{ + AnyOf: []functions.Item{d.JsonSchema.Schema}, + } + g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...) + if err == nil { + input.Grammar = g + } } } @@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } switch { - case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn: + case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn: noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, diff --git a/core/schema/openai.go b/core/schema/openai.go index 3b39eaf3..fe4745bf 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct { Type ChatCompletionResponseFormatType `json:"type,omitempty"` } +type JsonSchemaRequest struct { + Type string `json:"type"` + JsonSchema JsonSchema `json:"json_schema"` +} + +type JsonSchema struct { + Name string `json:"name"` + Strict bool `json:"strict"` + Schema functions.Item `json:"schema"` +} + type OpenAIRequest struct { PredictionOptions diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 19012d53..1a7e1ff1 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -14,6 +14,7 @@ const ( type Function struct { Name string `json:"name"` Description string `json:"description"` + Strict bool `json:"strict"` Parameters map[string]interface{} `json:"parameters"` } type Functions []Function