From bbea62b907db917b8ad7036d06b828da48269bf8 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 18 Apr 2024 22:43:12 +0200
Subject: [PATCH] feat(functions): support models with no grammar, add tests (#2068)

Signed-off-by: Ettore Di Giacinto
---
 core/config/backend_config.go                |  10 +-
 core/http/endpoints/openai/chat.go           | 131 ++++++------------
 core/http/endpoints/openai/completion.go     |   4 +-
 core/http/endpoints/openai/request.go        |   4 +-
 core/schema/openai.go                        |  14 +-
 pkg/{grammar => functions}/functions.go      |   2 +-
 .../functions_suite_test.go}                 |   2 +-
 pkg/{grammar => functions}/functions_test.go |   4 +-
 .../grammar_json_schema.go}                  |   2 +-
 .../grammar_json_schema_test.go}             |   4 +-
 pkg/functions/parse.go                       | 108 +++++++++++++++
 pkg/functions/parse_test.go                  |  85 ++++++++++++
 pkg/model/loader.go                          |   4 +-
 13 files changed, 255 insertions(+), 119 deletions(-)
 rename pkg/{grammar => functions}/functions.go (98%)
 rename pkg/{grammar/grammar_suite_test.go => functions/functions_suite_test.go} (90%)
 rename pkg/{grammar => functions}/functions_test.go (96%)
 rename pkg/{grammar/json_schema.go => functions/grammar_json_schema.go} (99%)
 rename pkg/{grammar/json_schema_test.go => functions/grammar_json_schema_test.go} (98%)
 create mode 100644 pkg/functions/parse.go
 create mode 100644 pkg/functions/parse_test.go

diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index 81c92d01..1161cf9f 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -12,6 +12,7 @@ import (
 	"github.com/go-skynet/LocalAI/core/schema"
 	"github.com/go-skynet/LocalAI/pkg/downloader"
+	"github.com/go-skynet/LocalAI/pkg/functions"
 	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/rs/zerolog/log"
 	"gopkg.in/yaml.v3"
 )
@@ -39,7 +40,7 @@ type BackendConfig struct {
 	InputToken [][]int `yaml:"-"`
 	functionCallString, functionCallNameString string `yaml:"-"`
 
-	FunctionsConfig Functions `yaml:"function"`
+	FunctionsConfig functions.FunctionsConfig `yaml:"function"`
 
 	FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
 	// LLM configs (GPT4ALL, Llama.cpp, ...)
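Note for reviewers: the `function` block of a model's YAML config now deserializes into `functions.FunctionsConfig` (introduced in pkg/functions/parse.go further down). As a rough sketch, with illustrative values only (field names come from the struct's yaml tags, the regex mirrors the one used in parse_test.go), a model can now opt out of grammar-constrained decoding like so:

	package main

	import (
		"fmt"

		"github.com/go-skynet/LocalAI/pkg/functions"
	)

	func main() {
		// Illustrative values only: this is what a model's `function:` YAML
		// block unmarshals into after this patch.
		cfg := functions.FunctionsConfig{
			DisableNoAction: false,
			ParallelCalls:   false,
			NoGrammar:       true, // skip grammar-constrained decoding, parse the raw reply instead
			ResponseRegex:   `(?P<function>\w+)\s*\((?P<arguments>.*)\)`, // named groups the parser looks up
		}
		fmt.Printf("%+v\n", cfg)
	}

DisableNoAction, NoActionFunctionName, NoActionDescriptionName, and ParallelCalls keep their previous meaning; NoGrammar and ResponseRegex are the new knobs this patch introduces.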
@@ -157,13 +158,6 @@ type AutoGPTQ struct {
 	UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
 }
 
-type Functions struct {
-	DisableNoAction         bool   `yaml:"disable_no_action"`
-	NoActionFunctionName    string `yaml:"no_action_function_name"`
-	NoActionDescriptionName string `yaml:"no_action_description_name"`
-	ParallelCalls           bool   `yaml:"parallel_calls"`
-}
-
 type TemplateConfig struct {
 	Chat        string `yaml:"chat"`
 	ChatMessage string `yaml:"chat_message"`
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 36d1142b..9adba8ea 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -11,9 +11,8 @@ import (
 	"github.com/go-skynet/LocalAI/core/backend"
 	"github.com/go-skynet/LocalAI/core/config"
 	"github.com/go-skynet/LocalAI/core/schema"
-	"github.com/go-skynet/LocalAI/pkg/grammar"
+	"github.com/go-skynet/LocalAI/pkg/functions"
 	model "github.com/go-skynet/LocalAI/pkg/model"
-	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/rs/zerolog/log"
@@ -68,8 +67,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			return true
 		})
 
-		results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
-		noActionToRun := len(results) > 0 && results[0].name == noAction
+		results := functions.ParseFunctionCall(result, config.FunctionsConfig)
+		noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
 
 		switch {
 		case noActionToRun:
@@ -82,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			}
 			responses <- initialMessage
 
-			result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
+			result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
 			if err != nil {
 				log.Error().Err(err).Msg("error handling question")
 				return
@@ -105,7 +104,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 
 		default:
 			for i, ss := range results {
-				name, args := ss.name, ss.arguments
+				name, args := ss.Name, ss.Arguments
 
 				initialMessage := schema.OpenAIResponse{
 					ID: id,
@@ -156,8 +155,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	}
 
 	return func(c *fiber.Ctx) error {
-		processFunctions := false
-		funcs := grammar.Functions{}
 		modelFile, input, err := readRequest(c, ml, startupOptions, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -169,6 +166,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		}
 		log.Debug().Msgf("Configuration read: %+v", config)
 
+		funcs := input.Functions
+		shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
+
 		// Allow the user to set custom actions via config file
 		// to be "embedded" in each model
 		noActionName := "answer"
@@ -182,18 +182,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		}
 
 		if input.ResponseFormat.Type == "json_object" {
-			input.Grammar = grammar.JSONBNF
+			input.Grammar = functions.JSONBNF
 		}
 
 		config.Grammar = input.Grammar
 
-		// process functions if we have any defined or if we have a function call string
-		if len(input.Functions) > 0 && config.ShouldUseFunctions() {
+		if shouldUseFn {
 			log.Debug().Msgf("Response needs to process functions")
+		}
 
-			processFunctions = true
-
-			noActionGrammar := grammar.Function{
+		switch {
+		case !config.FunctionsConfig.NoGrammar && shouldUseFn:
+			noActionGrammar := functions.Function{
 				Name:        noActionName,
 				Description: noActionDescription,
 				Parameters: map[string]interface{}{
@@ -206,7 +206,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			}
 
 			// Append the no action function
-			funcs = append(funcs, input.Functions...)
 			if !config.FunctionsConfig.DisableNoAction {
 				funcs = append(funcs, noActionGrammar)
 			}
@@ -219,10 +218,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			// Update input grammar
 			jsStruct := funcs.ToJSONStructure()
 			config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
-		} else if input.JSONFunctionGrammarObject != nil {
+		case input.JSONFunctionGrammarObject != nil:
 			config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
+		default:
+			// Force picking one of the functions by the request
+			if config.FunctionToCall() != "" {
+				funcs = funcs.Select(config.FunctionToCall())
+			}
 		}
 
+		// process functions if we have any defined or if we have a function call string
+
 		// functions are not supported in stream mode (yet?)
 		toStream := input.Stream
 
@@ -232,8 +238,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 
 		// If we are using the tokenizer template, we don't need to process the messages
 		// unless we are processing functions
-		if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {
-
+		if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
 			suppressConfigSystemPrompt := false
 			mess := []string{}
 			for messageIndex, i := range input.Messages {
@@ -346,11 +351,11 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 				templateFile = config.Model
 			}
 
-			if config.TemplateConfig.Chat != "" && !processFunctions {
+			if config.TemplateConfig.Chat != "" && !shouldUseFn {
 				templateFile = config.TemplateConfig.Chat
 			}
 
-			if config.TemplateConfig.Functions != "" && processFunctions {
+			if config.TemplateConfig.Functions != "" && shouldUseFn {
 				templateFile = config.TemplateConfig.Functions
 			}
 
@@ -370,7 +375,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			}
 
 			log.Debug().Msgf("Prompt (after templating): %s", predInput)
-			if processFunctions {
+			if shouldUseFn && config.Grammar != "" {
 				log.Debug().Msgf("Grammar: %+v", config.Grammar)
 			}
 		}
@@ -388,7 +393,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 
 			responses := make(chan schema.OpenAIResponse)
 
-			if !processFunctions {
+			if !shouldUseFn {
 				go process(predInput, input, config, ml, responses)
 			} else {
 				go processTools(noActionName, predInput, input, config, ml, responses)
@@ -446,18 +451,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	// no streaming mode
 	default:
 		result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
-			if !processFunctions {
+			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
 				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
 				return
 			}
 
-			results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
-			noActionsToRun := len(results) > 0 && results[0].name == noActionName
+			results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+			noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
 
 			switch {
 			case noActionsToRun:
-				result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
+				result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
 				if err != nil {
 					log.Error().Err(err).Msg("error handling question")
 					return
@@ -476,7 +481,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 				}
 
 				for _, ss := range results {
-					name, args := ss.name, ss.arguments
+					name, args := ss.Name, ss.Arguments
 					if len(input.Tools) > 0 {
 						// If we are using tools, we condense the function calls into
 						// a single response choice with all the tools
@@ -534,16 +539,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		// Return the prediction in the response body
 		return c.JSON(resp)
 	}
-
 }
 
-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
 	log.Debug().Msgf("nothing to do, computing a reply")
-
+	arg := ""
+	if len(funcResults) > 0 {
+		arg = funcResults[0].Arguments
+	}
 	// If there is a message that the LLM already sends as part of the JSON reply, use it
 	arguments := map[string]interface{}{}
-	json.Unmarshal([]byte(args), &arguments)
+	if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
+		log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object")
+	}
 	m, exists := arguments["message"]
 	if exists {
 		switch message := m.(type) {
@@ -580,63 +589,3 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 	}
 	return backend.Finetune(*config, prompt, prediction.Response), nil
 }
-
-type funcCallResults struct {
-	name      string
-	arguments string
-}
-
-func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
-	results := []funcCallResults{}
-
-	// TODO: use generics to avoid this code duplication
-	if multipleResults {
-		ss := []map[string]interface{}{}
-		s := utils.EscapeNewLines(llmresult)
-		json.Unmarshal([]byte(s), &ss)
-		log.Debug().Msgf("Function return: %s %+v", s, ss)
-
-		for _, s := range ss {
-			func_name, ok := s["function"]
-			if !ok {
-				continue
-			}
-			args, ok := s["arguments"]
-			if !ok {
-				continue
-			}
-			d, _ := json.Marshal(args)
-			funcName, ok := func_name.(string)
-			if !ok {
-				continue
-			}
-			results = append(results, funcCallResults{name: funcName, arguments: string(d)})
-		}
-	} else {
-		// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
-		ss := map[string]interface{}{}
-		// This prevent newlines to break JSON parsing for clients
-		s := utils.EscapeNewLines(llmresult)
-		json.Unmarshal([]byte(s), &ss)
-		log.Debug().Msgf("Function return: %s %+v", s, ss)
-
-		// The grammar defines the function name as "function", while OpenAI returns "name"
-		func_name, ok := ss["function"]
-		if !ok {
-			return results
-		}
-		// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
-		args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
-		if !ok {
-			return results
-		}
-		d, _ := json.Marshal(args)
-		funcName, ok := func_name.(string)
-		if !ok {
-			return results
-		}
-		results = append(results, funcCallResults{name: funcName, arguments: string(d)})
-	}
-
-	return results
-}
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index 69923475..bcd46db5 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -12,7 +12,7 @@ import (
 	"github.com/go-skynet/LocalAI/core/config"
 	"github.com/go-skynet/LocalAI/core/schema"
-	"github.com/go-skynet/LocalAI/pkg/grammar"
+	"github.com/go-skynet/LocalAI/pkg/functions"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
@@ -70,7 +70,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 	}
 
 	if input.ResponseFormat.Type == "json_object" {
-		input.Grammar = grammar.JSONBNF
+		input.Grammar = functions.JSONBNF
 	}
 
 	config.Grammar = input.Grammar
diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go
index 369fb0b8..9a107bab 100644
--- a/core/http/endpoints/openai/request.go
+++ b/core/http/endpoints/openai/request.go
@@ -12,7 +12,7 @@ import (
 	"github.com/go-skynet/LocalAI/core/config"
 	fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
 	"github.com/go-skynet/LocalAI/core/schema"
-	"github.com/go-skynet/LocalAI/pkg/grammar"
+	"github.com/go-skynet/LocalAI/pkg/functions"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
@@ -145,7 +145,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 	}
 
 	if input.ToolsChoice != nil {
-		var toolChoice grammar.Tool
+		var toolChoice functions.Tool
 
 		switch content := input.ToolsChoice.(type) {
 		case string:
diff --git a/core/schema/openai.go b/core/schema/openai.go
index 6aa0f1b0..a251ba68 100644
--- a/core/schema/openai.go
+++ b/core/schema/openai.go
@@ -3,7 +3,7 @@ package schema
 import (
 	"context"
 
-	"github.com/go-skynet/LocalAI/pkg/grammar"
+	functions "github.com/go-skynet/LocalAI/pkg/functions"
 )
 
 // APIError provides error information returned by the OpenAI API.
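Note for reviewers: from the caller's side, the schema change in the next hunks only swaps the element types (`functions.Functions` instead of `[]grammar.Function`, `[]functions.Tool` instead of `[]grammar.Tool`), so requests are built exactly as before. A hedged sketch (the "add" function and its parameter schema are invented for illustration; the field types are taken from the hunks below):

	package main

	import (
		"fmt"

		"github.com/go-skynet/LocalAI/core/schema"
		"github.com/go-skynet/LocalAI/pkg/functions"
	)

	func main() {
		prompt := "What is 5 + 3?"
		// Hypothetical request using the relocated types.
		req := schema.OpenAIRequest{
			Messages: []schema.Message{{Role: "user", Content: &prompt}},
			Functions: functions.Functions{
				{
					Name:        "add",
					Description: "Add two numbers",
					Parameters: map[string]interface{}{ // JSON-schema style, as elsewhere in the codebase
						"type": "object",
						"properties": map[string]interface{}{
							"x": map[string]interface{}{"type": "number"},
							"y": map[string]interface{}{"type": "number"},
						},
					},
				},
			},
		}
		fmt.Printf("functions in request: %d\n", len(req.Functions))
	}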
@@ -108,7 +108,7 @@ type ChatCompletionResponseFormat struct {
 type OpenAIRequest struct {
 	PredictionOptions
 
-	Context context.Context `json:"-"`
+	Context context.Context    `json:"-"`
 	Cancel  context.CancelFunc `json:"-"`
 
 	// whisper
@@ -130,11 +130,11 @@ type OpenAIRequest struct {
 	Messages []Message `json:"messages" yaml:"messages"`
 
 	// A list of available functions to call
-	Functions []grammar.Function `json:"functions" yaml:"functions"`
-	FunctionCall interface{}     `json:"function_call" yaml:"function_call"` // might be a string or an object
+	Functions    functions.Functions `json:"functions" yaml:"functions"`
+	FunctionCall interface{}         `json:"function_call" yaml:"function_call"` // might be a string or an object
 
-	Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
-	ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
+	Tools       []functions.Tool `json:"tools,omitempty" yaml:"tools"`
+	ToolsChoice interface{}      `json:"tool_choice,omitempty" yaml:"tool_choice"`
 
 	Stream bool `json:"stream"`
 
@@ -145,7 +145,7 @@ type OpenAIRequest struct {
 
 	// A grammar to constrain the LLM output
 	Grammar string `json:"grammar" yaml:"grammar"`
 
-	JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
+	JSONFunctionGrammarObject *functions.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
 
 	Backend string `json:"backend" yaml:"backend"`
diff --git a/pkg/grammar/functions.go b/pkg/functions/functions.go
similarity index 98%
rename from pkg/grammar/functions.go
rename to pkg/functions/functions.go
index 1038f5e6..d75a2ee3 100644
--- a/pkg/grammar/functions.go
+++ b/pkg/functions/functions.go
@@ -1,4 +1,4 @@
-package grammar
+package functions
 
 import (
 	"encoding/json"
diff --git a/pkg/grammar/grammar_suite_test.go b/pkg/functions/functions_suite_test.go
similarity index 90%
rename from pkg/grammar/grammar_suite_test.go
rename to pkg/functions/functions_suite_test.go
index 652643b6..8964b1c8 100644
--- a/pkg/grammar/grammar_suite_test.go
+++ b/pkg/functions/functions_suite_test.go
@@ -1,4 +1,4 @@
-package grammar
+package functions
 
 import (
 	"testing"
diff --git a/pkg/grammar/functions_test.go b/pkg/functions/functions_test.go
similarity index 96%
rename from pkg/grammar/functions_test.go
rename to pkg/functions/functions_test.go
index 6e8a56ed..97953a5e 100644
--- a/pkg/grammar/functions_test.go
+++ b/pkg/functions/functions_test.go
@@ -1,7 +1,7 @@
-package grammar_test
+package functions_test
 
 import (
-	. "github.com/go-skynet/LocalAI/pkg/grammar"
+	. "github.com/go-skynet/LocalAI/pkg/functions"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 )
diff --git a/pkg/grammar/json_schema.go b/pkg/functions/grammar_json_schema.go
similarity index 99%
rename from pkg/grammar/json_schema.go
rename to pkg/functions/grammar_json_schema.go
index 76f9778f..01046390 100644
--- a/pkg/grammar/json_schema.go
+++ b/pkg/functions/grammar_json_schema.go
@@ -1,4 +1,4 @@
-package grammar
+package functions
 
 // a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
 
diff --git a/pkg/grammar/json_schema_test.go b/pkg/functions/grammar_json_schema_test.go
similarity index 98%
rename from pkg/grammar/json_schema_test.go
rename to pkg/functions/grammar_json_schema_test.go
index 39d2a4d5..fc9029a8 100644
--- a/pkg/grammar/json_schema_test.go
+++ b/pkg/functions/grammar_json_schema_test.go
@@ -1,9 +1,9 @@
-package grammar_test
+package functions_test
 
 import (
 	"strings"
 
-	. "github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/grammar" + . "github.com/go-skynet/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go new file mode 100644 index 00000000..5324e8c6 --- /dev/null +++ b/pkg/functions/parse.go @@ -0,0 +1,108 @@ +package functions + +import ( + "encoding/json" + "regexp" + + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +type FunctionsConfig struct { + DisableNoAction bool `yaml:"disable_no_action"` + NoActionFunctionName string `yaml:"no_action_function_name"` + NoActionDescriptionName string `yaml:"no_action_description_name"` + ParallelCalls bool `yaml:"parallel_calls"` + NoGrammar bool `yaml:"no_grammar"` + ResponseRegex string `yaml:"response_regex"` +} + +type FuncCallResults struct { + Name string + Arguments string +} + +func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults { + multipleResults := functionConfig.ParallelCalls + useGrammars := !functionConfig.NoGrammar + + results := []FuncCallResults{} + + // if no grammar is used, we have to extract function and arguments from the result + if !useGrammars { + // the response is a string that we have to parse + + // We use named regexes here to extract the function name and arguments + // obviously, this expects the LLM to be stable and return correctly formatted JSON + // TODO: optimize this and pre-compile it + var respRegex = regexp.MustCompile(functionConfig.ResponseRegex) + match := respRegex.FindStringSubmatch(llmresult) + result := make(map[string]string) + for i, name := range respRegex.SubexpNames() { + if i != 0 && name != "" && len(match) > i { + result[name] = match[i] + } + } + + // TODO: open point about multiple results and/or mixed with chat messages + // This is not handled as for now, we only expect one function call per response + functionName := result["function"] + if functionName == "" { + return results + } + + return append(results, FuncCallResults{Name: result["function"], Arguments: result["arguments"]}) + } + + // with grammars + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + for _, s := range ss { + func_name, ok := s["function"] + if !ok { + continue + } + args, ok := s["arguments"] + if !ok { + continue + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + continue + } + results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) 
+		ss := map[string]interface{}{}
+		// This prevents newlines from breaking JSON parsing for clients
+		s := utils.EscapeNewLines(llmresult)
+		json.Unmarshal([]byte(s), &ss)
+		log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+		// The grammar defines the function name as "function", while OpenAI returns "name"
+		func_name, ok := ss["function"]
+		if !ok {
+			return results
+		}
+		// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
+		args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+		if !ok {
+			return results
+		}
+		d, _ := json.Marshal(args)
+		funcName, ok := func_name.(string)
+		if !ok {
+			return results
+		}
+		results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
+	}
+
+	return results
+}
diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go
new file mode 100644
index 00000000..5168a7d1
--- /dev/null
+++ b/pkg/functions/parse_test.go
@@ -0,0 +1,85 @@
+package functions_test
+
+import (
+	. "github.com/go-skynet/LocalAI/pkg/functions"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("LocalAI function parse tests", func() {
+	var functionConfig FunctionsConfig
+
+	BeforeEach(func() {
+		// Default configuration setup
+		functionConfig = FunctionsConfig{
+			ParallelCalls: false,
+			NoGrammar:     false,
+			ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
+		}
+	})
+
+	Context("when using grammars and single result expected", func() {
+		It("should parse the function name and arguments correctly", func() {
+			input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
+			functionConfig.ParallelCalls = false
+			functionConfig.NoGrammar = false
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(1))
+			Expect(results[0].Name).To(Equal("add"))
+			Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+		})
+	})
+
+	Context("when not using grammars and regex is needed", func() {
+		It("should extract function name and arguments from the regex", func() {
+			input := `add({"x":5,"y":3})`
+			functionConfig.NoGrammar = true
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(1))
+			Expect(results[0].Name).To(Equal("add"))
+			Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+		})
+	})
+
+	Context("when having invalid input", func() {
+		It("returns no results when there is no input", func() {
+			input := ""
+			functionConfig.NoGrammar = true
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(0))
+
+			functionConfig.NoGrammar = false
+
+			results = ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(0))
+		})
+		It("returns no results when the input is invalid", func() {
+			input := "invalid input"
+			functionConfig.NoGrammar = true
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(0))
+
+			functionConfig.NoGrammar = false
+
+			results = ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(0))
+		})
+	})
+	Context("when parallel calls are enabled", func() {
+		It("should handle multiple function calls", func() {
+			input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`
+			functionConfig.ParallelCalls = true
+			functionConfig.NoGrammar = false
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(2))
+			Expect(results[0].Name).To(Equal("add"))
+			Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+			Expect(results[1].Name).To(Equal("subtract"))
+			Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`))
+		})
+	})
+})
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 003d8327..f3182940 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -11,7 +11,7 @@ import (
 	"text/template"
 
 	"github.com/Masterminds/sprig/v3"
-	grammar "github.com/go-skynet/LocalAI/pkg/grammar"
+	"github.com/go-skynet/LocalAI/pkg/functions"
 	"github.com/go-skynet/LocalAI/pkg/grpc"
 	process "github.com/mudler/go-processmanager"
 	"github.com/rs/zerolog/log"
@@ -25,7 +25,7 @@ type PromptTemplateData struct {
 	SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
 	Input                string
 	Instruction          string
-	Functions            []grammar.Function
+	Functions            []functions.Function
 	MessageIndex         int
 }
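For completeness, a short, self-contained sketch of how the new exported API behaves end to end. The inputs mirror the fixtures in parse_test.go above, and the printed values are exactly what those tests assert:

	package main

	import (
		"fmt"

		"github.com/go-skynet/LocalAI/pkg/functions"
	)

	func main() {
		// Grammar-constrained mode (the default): the model emits a JSON object
		// with "function" and "arguments" keys.
		grammarCfg := functions.FunctionsConfig{}
		calls := functions.ParseFunctionCall(`{"function": "add", "arguments": {"x": 5, "y": 3}}`, grammarCfg)
		fmt.Println(calls[0].Name, calls[0].Arguments) // add {"x":5,"y":3}

		// No-grammar mode (new in this patch): the raw reply is matched against
		// a named regex; the "function" and "arguments" groups are required.
		regexCfg := functions.FunctionsConfig{
			NoGrammar:     true,
			ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
		}
		calls = functions.ParseFunctionCall(`add({"x":5,"y":3})`, regexCfg)
		fmt.Println(calls[0].Name, calls[0].Arguments) // add {"x":5,"y":3}
	}

Note that in no-grammar mode the arguments are returned verbatim from the regex capture, while in grammar mode they are re-marshaled from the parsed JSON object, so key ordering follows encoding/json.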