feat: LocalAI functions (#726)

This commit is contained in:
Ettore Di Giacinto 2023-07-09 13:38:46 +02:00 committed by GitHub
commit 7aaa10680d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 781 additions and 37 deletions

View File

@ -1,4 +1,3 @@
.git
.idea
models
examples/chatbot-ui/models

4
.env
View File

@ -26,8 +26,8 @@ MODELS_PATH=/models
## Specify a build type. Available: cublas, openblas, clblas.
# BUILD_TYPE=openblas
## Uncomment and set to false to disable rebuilding from source
# REBUILD=false
## Uncomment and set to true to enable rebuilding from source
# REBUILD=true
## Enable go tags, available: stablediffusion, tts
## stablediffusion: image generation with stablediffusion

View File

@ -83,6 +83,8 @@ RUN make get-sources
COPY go.mod .
RUN make prepare
COPY . .
COPY .git .
RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build
###################################
@ -92,7 +94,7 @@ FROM requirements
ARG FFMPEG
ENV REBUILD=true
ENV REBUILD=false
ENV HEALTHCHECK_ENDPOINT=http://localhost:8080/readyz
# Add FFmpeg

View File

@ -3,24 +3,51 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
GOLLAMA_VERSION?=ecd358d2f144b4282a73df443d60474fca5db9ec
# llama.cpp versions
# Temporarly pinned to https://github.com/go-skynet/go-llama.cpp/pull/124
GOLLAMA_VERSION?=cb8d7cd4cb95725a04504a9e3a26dd72a12b69ac
# Temporary set a specific version of llama.cpp
# containing: https://github.com/ggerganov/llama.cpp/pull/1773 and
# rebased on top of master.
# This pin can be dropped when the PR above is merged, and go-llama has merged changes as well
# Set empty to use the version pinned by go-llama
LLAMA_CPP_REPO?=https://github.com/mudler/llama.cpp
LLAMA_CPP_VERSION?=48ce8722a05a018681634af801fd0fd45b3a87cc
# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
GPT4ALL_VERSION?=70cbff70cc2a9ad26d492d44ab582d32e6219956
# go-ggml-transformers version
GOGGMLTRANSFORMERS_VERSION?=8e31841dcddca16468c11b2e7809f279fa76a832
# go-rwkv version
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=f5a8c45396741470583f59b916a2a7641e63bcd0
# whisper.cpp version
WHISPER_CPP_VERSION?=85ed71aaec8e0612a84c0b67804bde75aa75a273
# bert.cpp version
BERT_VERSION?=6069103f54b9969c02e789d0fb12a23bd614285f
# go-piper version
PIPER_VERSION?=56b8a81b4760a6fbee1a82e62f007ae7e8f010a7
# go-bloomz version
BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f
# stablediffusion version
STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
export BUILD_TYPE?=
CGO_LDFLAGS?=
CUDA_LIBPATH?=/usr/local/cuda/lib64/
STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
GO_TAGS?=
BUILD_ID?=git
VERSION?=$(shell git describe --always --tags --dirty || echo "dev" )
VERSION?=$(shell git describe --always --tags || echo "dev" )
# go tool nm ./local-ai | grep Commit
LD_FLAGS?=
override LD_FLAGS += -X "github.com/go-skynet/LocalAI/internal.Version=$(VERSION)"
@ -201,6 +228,9 @@ whisper.cpp/libwhisper.a: whisper.cpp
go-llama:
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
cd go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
ifneq ($(LLAMA_CPP_REPO),)
cd go-llama && rm -rf llama.cpp && git clone $(LLAMA_CPP_REPO) llama.cpp && cd llama.cpp && git checkout -b build $(LLAMA_CPP_VERSION) && git submodule update --init --recursive --depth 1
endif
go-llama/libbinding.a: go-llama
$(MAKE) -C go-llama BUILD_TYPE=$(BUILD_TYPE) libbinding.a
@ -227,6 +257,7 @@ prepare-sources: get-sources replace
## GENERIC
rebuild: ## Rebuilds the project
$(GOCMD) clean -cache
$(MAKE) -C go-llama clean
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
$(MAKE) -C go-ggml-transformers clean
@ -242,6 +273,7 @@ prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/lib
touch $@
clean: ## Remove build related file
$(GOCMD) clean -cache
rm -fr ./go-llama
rm -rf ./gpt4all
rm -rf ./go-gpt2

View File

@ -51,6 +51,9 @@ func App(opts ...AppOption) (*fiber.App, error) {
}))
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cm := NewConfigMerger()
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())

View File

@ -46,12 +46,24 @@ type Config struct {
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
PromptStrings, InputStrings []string
InputToken [][]int
Grammar string `yaml:"grammar"`
FunctionsConfig Functions `yaml:"function"`
PromptStrings, InputStrings []string
InputToken [][]int
functionCallString, functionCallNameString string
}
type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
}
type TemplateConfig struct {
Completion string `yaml:"completion"`
Functions string `yaml:"function"`
Chat string `yaml:"chat"`
Edit string `yaml:"edit"`
}
@ -181,6 +193,10 @@ func updateConfig(config *Config, input *OpenAIRequest) {
config.TopP = input.TopP
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
@ -261,6 +277,23 @@ func updateConfig(config *Config, input *OpenAIRequest) {
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.functionCallString = fnc
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.functionCallNameString = name
}
switch p := input.Prompt.(type) {
case string:

View File

@ -17,6 +17,7 @@ import (
"strings"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
llama "github.com/go-skynet/go-llama.cpp"
@ -73,8 +74,12 @@ type Choice struct {
}
type Message struct {
Role string `json:"role,omitempty" yaml:"role"`
Content string `json:"content,omitempty" yaml:"content"`
// The message role
Role string `json:"role,omitempty" yaml:"role"`
// The message content
Content *string `json:"content" yaml:"content"`
// A result of a function call
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
}
type OpenAIModel struct {
@ -104,6 +109,10 @@ type OpenAIRequest struct {
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages" yaml:"messages"`
// A list of available functions to call
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Stream bool `json:"stream"`
Echo bool `json:"echo"`
// Common options between all the API calls
@ -134,6 +143,11 @@ type OpenAIRequest struct {
Mode int `json:"mode"`
Step int `json:"step"`
// A grammar to constrain the LLM output
Grammar string `json:"grammar" yaml:"grammar"`
// A grammar object
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
}
@ -202,7 +216,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
@ -210,7 +224,9 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
}{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
@ -256,7 +272,9 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: i})
}{
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
@ -357,7 +375,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Content: s}, Index: 0}},
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
}
log.Debug().Msgf("Sending goroutine: %s", s)
@ -368,6 +386,8 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
model, input, err := readInput(c, o.loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@ -377,27 +397,116 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
log.Debug().Msgf("Parameter Config: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 &&
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.functionCallNameString != "" {
funcs = funcs.Select(config.functionCallNameString)
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream && !processFunctions
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
mess := []string{}
for _, i := range input.Messages {
var content string
r := config.Roles[i.Role]
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && *i.Content != ""
if r != "" {
content = fmt.Sprint(r, " ", i.Content)
if contentExists {
content = fmt.Sprint(r, " ", *i.Content)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
content = i.Content
if contentExists {
content = fmt.Sprint(*i.Content)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if input.Stream {
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
@ -409,20 +518,35 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
templateFile := config.Model
if config.TemplateConfig.Chat != "" {
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
Input string
Functions []grammar.Function
}{
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
if input.Stream {
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
if toStream {
responses := make(chan OpenAIResponse)
go process(predInput, input, config, o.loader, responses)
@ -459,7 +583,72 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
}
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
ss["arguments"] = string(d)
ss["name"] = func_name
// if do nothing, reply with a message
if func_name == noActionName {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(d), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
return
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction = Finetune(*config, predInput, prediction)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
} else {
// otherwise reply with the function call
*c = append(*c, Choice{
FinishReason: "function_call",
Message: &Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err

View File

@ -3,9 +3,11 @@ package api
import (
"context"
"embed"
"encoding/json"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type Option struct {
@ -69,6 +71,20 @@ func WithBackendAssets(f embed.FS) AppOption {
}
}
func WithStringGalleries(galls string) AppOption {
return func(o *Option) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
return
}
var galleries []gallery.Gallery
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
log.Error().Msgf("failed loading galleries: %s", err.Error())
}
o.galleries = append(o.galleries, galleries...)
}
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
o.galleries = append(o.galleries, galleries...)

View File

@ -189,6 +189,8 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
}
predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
if c.PromptCachePath != "" {
// Create parent directory
p := filepath.Join(modelPath, c.PromptCachePath)

View File

@ -6,6 +6,16 @@ cd /build
if [ "$REBUILD" != "false" ]; then
rm -rf ./local-ai
ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build -j${BUILD_PARALLELISM:-1}
else
echo "@@@@@"
echo "Skipping rebuild"
echo "@@@@@"
echo "If you are experiencing issues with the pre-compiled builds, try setting REBUILD=true"
echo "If you are still experiencing issues with the build, try setting CMAKE_ARGS and disable the instructions set as needed:"
echo 'CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF"'
echo "see the documentation at: https://localai.io/basics/build/index.html"
echo "Note: See also https://github.com/go-skynet/LocalAI/issues/288"
echo "@@@@@"
fi
./local-ai "$@"

View File

@ -6,5 +6,5 @@ var Version = ""
var Commit = ""
func PrintableVersion() string {
return fmt.Sprintf("LocalAI %s (%s)", Version, Commit)
return fmt.Sprintf("%s (%s)", Version, Commit)
}

14
main.go
View File

@ -1,14 +1,11 @@
package main
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
api "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -126,19 +123,13 @@ Some of the models compatible are:
- Alpaca
- StableLM (ggml quantized)
It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
For a list of compatible model, check out: https://localai.io/model-compatibility/index.html
`,
UsageText: `local-ai [options]`,
Copyright: "go-skynet authors",
Copyright: "Ettore Di Giacinto",
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
galls := ctx.String("galleries")
var galleries []gallery.Gallery
err := json.Unmarshal([]byte(galls), &galleries)
fmt.Println(err)
app, err := api.App(
api.WithConfigFile(ctx.String("config-file")),
api.WithGalleries(galleries),
api.WithJSONStringPreload(ctx.String("preload-models")),
api.WithYAMLConfigPreload(ctx.String("preload-models-config")),
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
@ -147,6 +138,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
api.WithImageDir(ctx.String("image-path")),
api.WithAudioDir(ctx.String("audio-path")),
api.WithF16(ctx.Bool("f16")),
api.WithStringGalleries(ctx.String("galleries")),
api.WithDisableMessage(false),
api.WithCors(ctx.Bool("cors")),
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),

View File

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
@ -17,6 +18,10 @@ type Gallery struct {
// Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err

50
pkg/grammar/functions.go Normal file
View File

@ -0,0 +1,50 @@
package grammar
import (
"encoding/json"
)
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
type Functions []Function
func (f Functions) ToJSONStructure() JSONFunctionStructure {
js := JSONFunctionStructure{}
for _, function := range f {
// t := function.Parameters["type"]
//tt := t.(string)
properties := function.Parameters["properties"]
dat, _ := json.Marshal(properties)
prop := map[string]interface{}{}
json.Unmarshal(dat, &prop)
js.OneOf = append(js.OneOf, Item{
Type: "object",
Properties: Properties{
Function: FunctionName{Const: function.Name},
Arguments: Argument{
Type: "object",
Properties: prop,
},
},
})
}
return js
}
// Select returns a list of functions containing the function with the given name
func (f Functions) Select(name string) Functions {
var funcs Functions
for _, f := range f {
if f.Name == name {
funcs = []Function{f}
break
}
}
return funcs
}

View File

@ -0,0 +1,63 @@
package grammar_test
import (
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("LocalAI grammar functions", func() {
Describe("ToJSONStructure()", func() {
It("converts a list of functions to a JSON structure that can be parsed to a grammar", func() {
var functions Functions = []Function{
{
Name: "create_event",
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"event_name": map[string]interface{}{
"type": "string",
},
"event_date": map[string]interface{}{
"type": "string",
},
},
},
},
{
Name: "search",
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
},
},
},
},
}
js := functions.ToJSONStructure()
Expect(len(js.OneOf)).To(Equal(2))
Expect(js.OneOf[0].Properties.Function.Const).To(Equal("create_event"))
Expect(js.OneOf[0].Properties.Arguments.Properties["event_name"].(map[string]interface{})["type"]).To(Equal("string"))
Expect(js.OneOf[0].Properties.Arguments.Properties["event_date"].(map[string]interface{})["type"]).To(Equal("string"))
Expect(js.OneOf[1].Properties.Function.Const).To(Equal("search"))
Expect(js.OneOf[1].Properties.Arguments.Properties["query"].(map[string]interface{})["type"]).To(Equal("string"))
})
})
Context("Select()", func() {
It("selects one of the functions and returns a list containing only the selected one", func() {
var functions Functions = []Function{
{
Name: "create_event",
},
{
Name: "search",
},
}
functions = functions.Select("create_event")
Expect(len(functions)).To(Equal(1))
Expect(functions[0].Name).To(Equal("create_event"))
})
})
})

View File

@ -0,0 +1,13 @@
package grammar
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestGrammar(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Grammar test suite")
}

222
pkg/grammar/json_schema.go Normal file
View File

@ -0,0 +1,222 @@
package grammar
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
var (
SPACE_RULE = `" "?`
PRIMITIVE_RULES = map[string]string{
"boolean": `("true" | "false") space`,
"number": `[0-9]+ space`, // TODO complete
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
"null": `"null" space`,
}
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
GRAMMAR_LITERAL_ESCAPES = map[string]string{
"\r": `\r`,
"\n": `\n`,
`"`: `\"`,
}
)
type JSONSchemaConverter struct {
propOrder map[string]int
rules map[string]string
}
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
propOrderSlice := strings.Split(propOrder, ",")
propOrderMap := make(map[string]int)
for idx, name := range propOrderSlice {
propOrderMap[name] = idx
}
rules := make(map[string]string)
rules["space"] = SPACE_RULE
return &JSONSchemaConverter{
propOrder: propOrderMap,
rules: rules,
}
}
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string {
return GRAMMAR_LITERAL_ESCAPES[match]
})
return fmt.Sprintf(`"%s"`, escaped)
}
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
key := escName
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
i := 0
for {
key = fmt.Sprintf("%s%d", escName, i)
if _, ok := sc.rules[key]; !ok {
break
}
i++
}
}
sc.rules[key] = rule
return key
}
func (sc *JSONSchemaConverter) formatGrammar() string {
var lines []string
for name, rule := range sc.rules {
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
}
return strings.Join(lines, "\n")
}
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string {
st, existType := schema["type"]
var schemaType string
if existType {
schemaType = st.(string)
}
ruleName := name
if name == "" {
ruleName = "root"
}
_, oneOfExists := schema["oneOf"]
_, anyOfExists := schema["anyOf"]
if oneOfExists || anyOfExists {
var alternatives []string
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
if oneOfExists {
for i, altSchema := range oneOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
alternatives = append(alternatives, alternative)
}
} else if anyOfExists {
for i, altSchema := range anyOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
alternatives = append(alternatives, alternative)
}
}
rule := strings.Join(alternatives, " | ")
return sc.addRule(ruleName, rule)
} else if constVal, exists := schema["const"]; exists {
return sc.addRule(ruleName, sc.formatLiteral(constVal))
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
var enumRules []string
for _, enumVal := range enumVals {
enumRule := sc.formatLiteral(enumVal)
enumRules = append(enumRules, enumRule)
}
rule := strings.Join(enumRules, " | ")
return sc.addRule(ruleName, rule)
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
propOrder := sc.propOrder
var propPairs []struct {
propName string
propSchema map[string]interface{}
}
for propName, propSchema := range properties {
propPairs = append(propPairs, struct {
propName string
propSchema map[string]interface{}
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
}
sort.Slice(propPairs, func(i, j int) bool {
iOrder := propOrder[propPairs[i].propName]
jOrder := propOrder[propPairs[j].propName]
if iOrder != 0 && jOrder != 0 {
return iOrder < jOrder
}
return propPairs[i].propName < propPairs[j].propName
})
var rule strings.Builder
rule.WriteString(`"{" space`)
for i, propPair := range propPairs {
propName := propPair.propName
propSchema := propPair.propSchema
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName))
if i > 0 {
rule.WriteString(` "," space`)
}
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName))
}
rule.WriteString(` "}" space`)
return sc.addRule(ruleName, rule.String())
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName))
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
return sc.addRule(ruleName, rule)
} else {
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
if !exists {
panic(fmt.Sprintf("Unrecognized schema: %v", schema))
}
return sc.addRule(schemaType, primitiveRule)
}
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
sc.visit(schema, "")
return sc.formatGrammar()
}
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
var schema map[string]interface{}
_ = json.Unmarshal(b, &schema)
return sc.Grammar(schema)
}
func jsonString(v interface{}) string {
b, _ := json.Marshal(v)
return string(b)
}
type FunctionName struct {
Const string `json:"const"`
}
type Properties struct {
Function FunctionName `json:"function"`
Arguments Argument `json:"arguments"`
}
type Argument struct {
Type string `json:"type"`
Properties map[string]interface{} `json:"properties"`
}
type Item struct {
Type string `json:"type"`
Properties Properties `json:"properties"`
}
type JSONFunctionStructure struct {
OneOf []Item `json:"oneOf,omitempty"`
AnyOf []Item `json:"anyOf,omitempty"`
}
func (j JSONFunctionStructure) Grammar(propOrder string) string {
dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
}

View File

@ -0,0 +1,113 @@
package grammar_test
import (
"strings"
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
testInput1 = `
{
"oneOf": [
{
"type": "object",
"properties": {
"function": {"const": "create_event"},
"arguments": {
"type": "object",
"properties": {
"title": {"type": "string"},
"date": {"type": "string"},
"time": {"type": "string"}
}
}
}
},
{
"type": "object",
"properties": {
"function": {"const": "search"},
"arguments": {
"type": "object",
"properties": {
"query": {"type": "string"}
}
}
}
}
]
}`
inputResult1 = `root-0-function ::= "\"create_event\""
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
root ::= root-0 | root-1
space ::= " "?
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
string ::= "\"" [ \t!#-\[\]-~]* "\"" space
root-1-function ::= "\"search\""`
)
var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
Expect(grammar).To(ContainSubstring(r))
}
}
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
})
It("generates a valid grammar from JSON Objects", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: []Item{
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "create_event",
},
Arguments: Argument{ // this is OpenAI's parameter
Type: "object",
Properties: map[string]interface{}{
"title": map[string]string{"type": "string"},
"date": map[string]string{"type": "string"},
"time": map[string]string{"type": "string"},
},
},
},
},
{
Type: "object",
Properties: Properties{
Function: FunctionName{
Const: "search",
},
Arguments: Argument{
Type: "object",
Properties: map[string]interface{}{
"query": map[string]string{"type": "string"},
},
},
},
},
}}
grammar := structuredGrammar.Grammar("")
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
Expect(grammar).To(ContainSubstring(r))
}
}
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
})
})
})