mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(functions): support models with no grammar, add tests (#2068)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
13012cfa70
commit
bbea62b907
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
@ -39,7 +40,7 @@ type BackendConfig struct {
|
||||
InputToken [][]int `yaml:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-"`
|
||||
|
||||
FunctionsConfig Functions `yaml:"function"`
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
@ -157,13 +158,6 @@ type AutoGPTQ struct {
|
||||
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
|
||||
}
|
||||
|
||||
type Functions struct {
|
||||
DisableNoAction bool `yaml:"disable_no_action"`
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
ParallelCalls bool `yaml:"parallel_calls"`
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
Chat string `yaml:"chat"`
|
||||
ChatMessage string `yaml:"chat_message"`
|
||||
|
@ -11,9 +11,8 @@ import (
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -68,8 +67,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
return true
|
||||
})
|
||||
|
||||
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
|
||||
noActionToRun := len(results) > 0 && results[0].name == noAction
|
||||
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
@ -82,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
|
||||
result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error handling question")
|
||||
return
|
||||
@ -105,7 +104,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
|
||||
default:
|
||||
for i, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
name, args := ss.Name, ss.Arguments
|
||||
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
@ -156,8 +155,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
@ -169,6 +166,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
funcs := input.Functions
|
||||
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
@ -182,18 +182,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
input.Grammar = functions.JSONBNF
|
||||
}
|
||||
|
||||
config.Grammar = input.Grammar
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
if shouldUseFn {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
}
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
switch {
|
||||
case !config.FunctionsConfig.NoGrammar && shouldUseFn:
|
||||
noActionGrammar := functions.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
@ -206,7 +206,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
@ -219,9 +218,16 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
case input.JSONFunctionGrammarObject != nil:
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
default:
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream
|
||||
@ -232,8 +238,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
|
||||
// If we are using the tokenizer template, we don't need to process the messages
|
||||
// unless we are processing functions
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {
|
||||
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
@ -346,11 +351,11 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
if config.TemplateConfig.Chat != "" && !shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
@ -370,7 +375,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
if shouldUseFn && config.Grammar != "" {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
}
|
||||
@ -388,7 +393,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
if !processFunctions {
|
||||
if !shouldUseFn {
|
||||
go process(predInput, input, config, ml, responses)
|
||||
} else {
|
||||
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||
@ -446,18 +451,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
// no streaming mode
|
||||
default:
|
||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
if !processFunctions {
|
||||
if !shouldUseFn {
|
||||
// no function is called, just reply and use stop as finish reason
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
return
|
||||
}
|
||||
|
||||
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
|
||||
noActionsToRun := len(results) > 0 && results[0].name == noActionName
|
||||
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
|
||||
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
|
||||
result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error handling question")
|
||||
return
|
||||
@ -476,7 +481,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
for _, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
name, args := ss.Name, ss.Arguments
|
||||
if len(input.Tools) > 0 {
|
||||
// If we are using tools, we condense the function calls into
|
||||
// a single response choice with all the tools
|
||||
@ -534,16 +539,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
|
||||
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
arg := ""
|
||||
if len(funcResults) > 0 {
|
||||
arg = funcResults[0].Arguments
|
||||
}
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(args), &arguments)
|
||||
if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
|
||||
log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object")
|
||||
}
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
@ -580,63 +589,3 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
|
||||
}
|
||||
return backend.Finetune(*config, prompt, prediction.Response), nil
|
||||
}
|
||||
|
||||
type funcCallResults struct {
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
|
||||
results := []funcCallResults{}
|
||||
|
||||
// TODO: use generics to avoid this code duplication
|
||||
if multipleResults {
|
||||
ss := []map[string]interface{}{}
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
for _, s := range ss {
|
||||
func_name, ok := s["function"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args, ok := s["arguments"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
} else {
|
||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name, ok := ss["function"]
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
@ -70,7 +70,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
input.Grammar = functions.JSONBNF
|
||||
}
|
||||
|
||||
config.Grammar = input.Grammar
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -145,7 +145,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice grammar.Tool
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
|
@ -3,7 +3,7 @@ package schema
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
functions "github.com/go-skynet/LocalAI/pkg/functions"
|
||||
)
|
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
@ -130,10 +130,10 @@ type OpenAIRequest struct {
|
||||
Messages []Message `json:"messages" yaml:"messages"`
|
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||
Functions functions.Functions `json:"functions" yaml:"functions"`
|
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
|
||||
Tools []functions.Tool `json:"tools,omitempty" yaml:"tools"`
|
||||
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
|
||||
|
||||
Stream bool `json:"stream"`
|
||||
@ -145,7 +145,7 @@ type OpenAIRequest struct {
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"`
|
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
||||
JSONFunctionGrammarObject *functions.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
||||
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
package grammar
|
||||
package functions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
@ -1,4 +1,4 @@
|
||||
package grammar
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
@ -1,7 +1,7 @@
|
||||
package grammar_test
|
||||
package functions_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
. "github.com/go-skynet/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
package grammar
|
||||
package functions
|
||||
|
||||
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||
|
@ -1,9 +1,9 @@
|
||||
package grammar_test
|
||||
package functions_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
. "github.com/go-skynet/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
108
pkg/functions/parse.go
Normal file
108
pkg/functions/parse.go
Normal file
@ -0,0 +1,108 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type FunctionsConfig struct {
|
||||
DisableNoAction bool `yaml:"disable_no_action"`
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
ParallelCalls bool `yaml:"parallel_calls"`
|
||||
NoGrammar bool `yaml:"no_grammar"`
|
||||
ResponseRegex string `yaml:"response_regex"`
|
||||
}
|
||||
|
||||
type FuncCallResults struct {
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults {
|
||||
multipleResults := functionConfig.ParallelCalls
|
||||
useGrammars := !functionConfig.NoGrammar
|
||||
|
||||
results := []FuncCallResults{}
|
||||
|
||||
// if no grammar is used, we have to extract function and arguments from the result
|
||||
if !useGrammars {
|
||||
// the response is a string that we have to parse
|
||||
|
||||
// We use named regexes here to extract the function name and arguments
|
||||
// obviously, this expects the LLM to be stable and return correctly formatted JSON
|
||||
// TODO: optimize this and pre-compile it
|
||||
var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
|
||||
match := respRegex.FindStringSubmatch(llmresult)
|
||||
result := make(map[string]string)
|
||||
for i, name := range respRegex.SubexpNames() {
|
||||
if i != 0 && name != "" && len(match) > i {
|
||||
result[name] = match[i]
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: open point about multiple results and/or mixed with chat messages
|
||||
// This is not handled as for now, we only expect one function call per response
|
||||
functionName := result["function"]
|
||||
if functionName == "" {
|
||||
return results
|
||||
}
|
||||
|
||||
return append(results, FuncCallResults{Name: result["function"], Arguments: result["arguments"]})
|
||||
}
|
||||
|
||||
// with grammars
|
||||
// TODO: use generics to avoid this code duplication
|
||||
if multipleResults {
|
||||
ss := []map[string]interface{}{}
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
for _, s := range ss {
|
||||
func_name, ok := s["function"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args, ok := s["arguments"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
|
||||
}
|
||||
} else {
|
||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name, ok := ss["function"]
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
85
pkg/functions/parse_test.go
Normal file
85
pkg/functions/parse_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package functions_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LocalAI function parse tests", func() {
|
||||
var functionConfig FunctionsConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
// Default configuration setup
|
||||
functionConfig = FunctionsConfig{
|
||||
ParallelCalls: false,
|
||||
NoGrammar: false,
|
||||
ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
|
||||
}
|
||||
})
|
||||
|
||||
Context("when using grammars and single result expected", func() {
|
||||
It("should parse the function name and arguments correctly", func() {
|
||||
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
|
||||
functionConfig.ParallelCalls = false
|
||||
functionConfig.NoGrammar = false
|
||||
|
||||
results := ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Name).To(Equal("add"))
|
||||
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when not using grammars and regex is needed", func() {
|
||||
It("should extract function name and arguments from the regex", func() {
|
||||
input := `add({"x":5,"y":3})`
|
||||
functionConfig.NoGrammar = true
|
||||
|
||||
results := ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Name).To(Equal("add"))
|
||||
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when having invalid input", func() {
|
||||
It("returns no results when there is no input", func() {
|
||||
input := ""
|
||||
functionConfig.NoGrammar = true
|
||||
|
||||
results := ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(0))
|
||||
|
||||
functionConfig.NoGrammar = false
|
||||
|
||||
results = ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(0))
|
||||
})
|
||||
It("returns no results when is invalid", func() {
|
||||
input := "invalid input"
|
||||
functionConfig.NoGrammar = true
|
||||
|
||||
results := ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(0))
|
||||
functionConfig.NoGrammar = false
|
||||
|
||||
results = ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
Context("when parallel calls are enabled", func() {
|
||||
It("should handle multiple function calls", func() {
|
||||
input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`
|
||||
functionConfig.ParallelCalls = true
|
||||
functionConfig.NoGrammar = false
|
||||
|
||||
results := ParseFunctionCall(input, functionConfig)
|
||||
Expect(results).To(HaveLen(2))
|
||||
Expect(results[0].Name).To(Equal("add"))
|
||||
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
|
||||
Expect(results[1].Name).To(Equal("subtract"))
|
||||
Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`))
|
||||
})
|
||||
})
|
||||
})
|
@ -11,7 +11,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -25,7 +25,7 @@ type PromptTemplateData struct {
|
||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []grammar.Function
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user