From 27ec84827c40a81663ef4df51c5e9e30bbb458c9 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 19 Apr 2024 04:40:18 +0200 Subject: [PATCH] refactor(template): isolate and add tests (#2069) * refactor(template): isolate and add tests Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto Signed-off-by: Dave Co-authored-by: Dave --- pkg/model/loader.go | 111 +++++------------------------- pkg/model/loader_test.go | 7 +- pkg/templates/cache.go | 103 +++++++++++++++++++++++++++ pkg/templates/cache_test.go | 73 ++++++++++++++++++++ pkg/templates/utils_suite_test.go | 13 ++++ pkg/utils/path.go | 6 ++ 6 files changed, 218 insertions(+), 95 deletions(-) create mode 100644 pkg/templates/cache.go create mode 100644 pkg/templates/cache_test.go create mode 100644 pkg/templates/utils_suite_test.go diff --git a/pkg/model/loader.go b/pkg/model/loader.go index f3182940..1b5c9aa0 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -1,18 +1,19 @@ package model import ( - "bytes" "context" "fmt" "os" "path/filepath" "strings" "sync" - "text/template" - "github.com/Masterminds/sprig/v3" + "github.com/go-skynet/LocalAI/pkg/templates" + "github.com/go-skynet/LocalAI/pkg/functions" "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/utils" + process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -42,21 +43,6 @@ type ChatMessageTemplateData struct { LastMessage bool } -// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? -// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go -type TemplateType int - -const ( - ChatPromptTemplate TemplateType = iota - ChatMessageTemplate - CompletionPromptTemplate - EditPromptTemplate - FunctionsPromptTemplate - - // The following TemplateType is **NOT** a valid value and MUST be last. It exists to make the sanity integration tests simpler! - IntegrationTestTemplate -) - // new idea: what if we declare a struct of these here, and use a loop to check? // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl @@ -67,7 +53,7 @@ type ModelLoader struct { grpcClients map[string]grpc.Backend models map[string]ModelAddress grpcProcesses map[string]*process.Process - templates map[TemplateType]map[string]*template.Template + templates *templates.TemplateCache wd *WatchDog } @@ -86,11 +72,10 @@ func NewModelLoader(modelPath string) *ModelLoader { ModelPath: modelPath, grpcClients: make(map[string]grpc.Backend), models: make(map[string]ModelAddress), - templates: make(map[TemplateType]map[string]*template.Template), + templates: templates.NewTemplateCache(modelPath), grpcProcesses: make(map[string]*process.Process), } - nml.initializeTemplateMap() return nml } @@ -99,7 +84,7 @@ func (ml *ModelLoader) SetWatchDog(wd *WatchDog) { } func (ml *ModelLoader) ExistsInModelPath(s string) bool { - return existsInPath(ml.ModelPath, s) + return utils.ExistsInPath(ml.ModelPath, s) } func (ml *ModelLoader) ListModels() ([]string, error) { @@ -194,82 +179,22 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { return "" } -func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { +const ( + ChatPromptTemplate templates.TemplateType = iota + ChatMessageTemplate + CompletionPromptTemplate + EditPromptTemplate + FunctionsPromptTemplate +) + +func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { // TODO: should this check be improved? if templateType == ChatMessageTemplate { return "", fmt.Errorf("invalid templateType: ChatMessage") } - return ml.evaluateTemplate(templateType, templateName, in) + return ml.templates.EvaluateTemplate(templateType, templateName, in) } func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { - return ml.evaluateTemplate(ChatMessageTemplate, templateName, messageData) -} - -func existsInPath(path string, s string) bool { - _, err := os.Stat(filepath.Join(path, s)) - return err == nil -} - -func (ml *ModelLoader) initializeTemplateMap() { - // This also seems somewhat clunky as we reference the Test / End of valid data value slug, but it works? - for tt := TemplateType(0); tt < IntegrationTestTemplate; tt++ { - ml.templates[tt] = make(map[string]*template.Template) - } -} - -func (ml *ModelLoader) evaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) { - ml.mu.Lock() - defer ml.mu.Unlock() - - m, ok := ml.templates[templateType][templateName] - if !ok { - // return "", fmt.Errorf("template not loaded: %s", templateName) - loadErr := ml.loadTemplateIfExists(templateType, templateName) - if loadErr != nil { - return "", loadErr - } - m = ml.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked - } - if m == nil { - return "", fmt.Errorf("failed loading a template for %s", templateName) - } - - var buf bytes.Buffer - - if err := m.Execute(&buf, in); err != nil { - return "", err - } - return buf.String(), nil -} - -func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateName string) error { - // Check if the template was already loaded - if _, ok := ml.templates[templateType][templateName]; ok { - return nil - } - - // Check if the model path exists - // skip any error here - we run anyway if a template does not exist - modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) - - dat := "" - if ml.ExistsInModelPath(modelTemplateFile) { - d, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile)) - if err != nil { - return err - } - dat = string(d) - } else { - dat = templateName - } - - // Parse the template - tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) - if err != nil { - return err - } - ml.templates[templateType][templateName] = tmpl - - return nil + return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) } diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go index 4c3c1a88..e4207b35 100644 --- a/pkg/model/loader_test.go +++ b/pkg/model/loader_test.go @@ -92,10 +92,13 @@ var testMatch map[string]map[string]interface{} = map[string]map[string]interfac var _ = Describe("Templates", func() { Context("chat message", func() { - modelLoader := NewModelLoader("") + var modelLoader *ModelLoader + BeforeEach(func() { + modelLoader = NewModelLoader("") + }) for key := range testMatch { foo := testMatch[key] - It("renders correctly "+key, func() { + It("renders correctly `"+key+"`", func() { templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) Expect(err).ToNot(HaveOccurred()) Expect(templated).To(Equal(foo["expected"]), templated) diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go new file mode 100644 index 00000000..9ff55605 --- /dev/null +++ b/pkg/templates/cache.go @@ -0,0 +1,103 @@ +package templates + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sync" + "text/template" + + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/Masterminds/sprig/v3" +) + +// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? +// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go +type TemplateType int + +type TemplateCache struct { + mu sync.Mutex + templatesPath string + templates map[TemplateType]map[string]*template.Template +} + +func NewTemplateCache(templatesPath string) *TemplateCache { + tc := &TemplateCache{ + templatesPath: templatesPath, + templates: make(map[TemplateType]map[string]*template.Template), + } + return tc +} + +func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) { + if _, ok := tc.templates[tt]; !ok { + tc.templates[tt] = make(map[string]*template.Template) + } +} + +func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) { + tc.mu.Lock() + defer tc.mu.Unlock() + + tc.initializeTemplateMapKey(templateType) + m, ok := tc.templates[templateType][templateName] + if !ok { + // return "", fmt.Errorf("template not loaded: %s", templateName) + loadErr := tc.loadTemplateIfExists(templateType, templateName) + if loadErr != nil { + return "", loadErr + } + m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked + } + if m == nil { + return "", fmt.Errorf("failed loading a template for %s", templateName) + } + + var buf bytes.Buffer + + if err := m.Execute(&buf, in); err != nil { + return "", err + } + return buf.String(), nil +} + +func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error { + + // Check if the template was already loaded + if _, ok := tc.templates[templateType][templateName]; ok { + return nil + } + + // Check if the model path exists + // skip any error here - we run anyway if a template does not exist + modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) + + dat := "" + file := filepath.Join(tc.templatesPath, modelTemplateFile) + + // Security check + if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil { + return fmt.Errorf("template file outside path: %s", file) + } + + if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) { + d, err := os.ReadFile(file) + if err != nil { + return err + } + dat = string(d) + } else { + dat = templateName + } + + // Parse the template + tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) + if err != nil { + return err + } + tc.templates[templateType][templateName] = tmpl + + return nil +} diff --git a/pkg/templates/cache_test.go b/pkg/templates/cache_test.go new file mode 100644 index 00000000..83af02b2 --- /dev/null +++ b/pkg/templates/cache_test.go @@ -0,0 +1,73 @@ +package templates_test + +import ( + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/pkg/templates" // Update with your module path + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TemplateCache", func() { + var ( + templateCache *templates.TemplateCache + tempDir string + ) + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "templates") + Expect(err).NotTo(HaveOccurred()) + + // Writing example template files + err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0644) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0644) + Expect(err).NotTo(HaveOccurred()) + + templateCache = templates.NewTemplateCache(tempDir) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) // Clean up + }) + + Describe("EvaluateTemplate", func() { + Context("when template is loaded successfully", func() { + It("should evaluate the template correctly", func() { + result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("Hello, Gopher!")) + }) + }) + + Context("when template isn't a file", func() { + It("should parse from string", func() { + result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("Gopher")) + }) + }) + + Context("when template is empty", func() { + It("should return an empty string", func() { + result, err := templateCache.EvaluateTemplate(1, "empty", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("")) + }) + }) + }) + + Describe("concurrency", func() { + It("should handle multiple concurrent accesses", func(done Done) { + go func() { + _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + }() + go func() { + _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + }() + close(done) + }, 0.1) // timeout in seconds + }) +}) diff --git a/pkg/templates/utils_suite_test.go b/pkg/templates/utils_suite_test.go new file mode 100644 index 00000000..011ba8f6 --- /dev/null +++ b/pkg/templates/utils_suite_test.go @@ -0,0 +1,13 @@ +package templates_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestTemplates(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Templates test suite") +} diff --git a/pkg/utils/path.go b/pkg/utils/path.go index f95b0138..9982bc1e 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -2,10 +2,16 @@ package utils import ( "fmt" + "os" "path/filepath" "strings" ) +func ExistsInPath(path string, s string) bool { + _, err := os.Stat(filepath.Join(path, s)) + return err == nil +} + func inTrustedRoot(path string, trustedRoot string) error { for path != "/" { path = filepath.Dir(path)