From cea5a0ea42348f64b982ef7fb64796a86d2bd70e Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 8 Dec 2024 13:50:33 +0100 Subject: [PATCH] feat(template): read jinja templates from gguf files (#4332) * Read jinja templates as fallback Signed-off-by: Ettore Di Giacinto * Move templating out of model loader Signed-off-by: Ettore Di Giacinto * Test TemplateMessages Signed-off-by: Ettore Di Giacinto * Set role and content from transformers Signed-off-by: Ettore Di Giacinto * Tests: be more flexible Signed-off-by: Ettore Di Giacinto * More jinja Signed-off-by: Ettore Di Giacinto * Small refactoring and adaptations Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/application.go | 38 --- core/application/application.go | 39 +++ .../config_file_watcher.go | 4 +- core/{startup => application}/startup.go | 77 ++--- core/cli/run.go | 8 +- core/config/backend_config.go | 2 + core/config/guesser.go | 16 +- core/http/app.go | 73 +++-- core/http/app_test.go | 24 +- core/http/endpoints/openai/chat.go | 146 +-------- core/http/endpoints/openai/completion.go | 47 +-- core/http/endpoints/openai/edit.go | 33 +- core/http/routes/localai.go | 48 +-- core/http/routes/openai.go | 154 ++++++--- go.mod | 5 + go.sum | 12 + pkg/model/loader.go | 4 - pkg/model/template.go | 52 --- pkg/model/template_test.go | 197 ------------ pkg/templates/cache.go | 156 ++++++--- pkg/templates/cache_test.go | 73 ----- pkg/templates/evaluator.go | 295 ++++++++++++++++++ pkg/templates/evaluator_test.go | 253 +++++++++++++++ 23 files changed, 971 insertions(+), 785 deletions(-) delete mode 100644 core/application.go create mode 100644 core/application/application.go rename core/{startup => application}/config_file_watcher.go (96%) rename core/{startup => application}/startup.go (62%) delete mode 100644 pkg/model/template.go delete mode 100644 pkg/model/template_test.go delete mode 100644 pkg/templates/cache_test.go create mode 100644 pkg/templates/evaluator.go create mode 100644 pkg/templates/evaluator_test.go diff --git a/core/application.go b/core/application.go deleted file mode 100644 index e4efbdd0..00000000 --- a/core/application.go +++ /dev/null @@ -1,38 +0,0 @@ -package core - -import ( - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/model" -) - -// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy -// Perhaps a proper DI system is worth it in the future, but for now keep things simple. -type Application struct { - - // Application-Level Config - ApplicationConfig *config.ApplicationConfig - // ApplicationState *ApplicationState - - // Core Low-Level Services - BackendConfigLoader *config.BackendConfigLoader - ModelLoader *model.ModelLoader - - // Backend Services - // EmbeddingsBackendService *backend.EmbeddingsBackendService - // ImageGenerationBackendService *backend.ImageGenerationBackendService - // LLMBackendService *backend.LLMBackendService - // TranscriptionBackendService *backend.TranscriptionBackendService - // TextToSpeechBackendService *backend.TextToSpeechBackendService - - // LocalAI System Services - BackendMonitorService *services.BackendMonitorService - GalleryService *services.GalleryService - LocalAIMetricsService *services.LocalAIMetricsService - // OpenAIService *services.OpenAIService -} - -// TODO [NEXT PR?]: Break up ApplicationConfig. 
-// Migrate over stuff that is not set via config at all - especially runtime stuff -type ApplicationState struct { -} diff --git a/core/application/application.go b/core/application/application.go new file mode 100644 index 00000000..6e8d6204 --- /dev/null +++ b/core/application/application.go @@ -0,0 +1,39 @@ +package application + +import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" +) + +type Application struct { + backendLoader *config.BackendConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig + templatesEvaluator *templates.Evaluator +} + +func newApplication(appConfig *config.ApplicationConfig) *Application { + return &Application{ + backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath), + modelLoader: model.NewModelLoader(appConfig.ModelPath), + applicationConfig: appConfig, + templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath), + } +} + +func (a *Application) BackendLoader() *config.BackendConfigLoader { + return a.backendLoader +} + +func (a *Application) ModelLoader() *model.ModelLoader { + return a.modelLoader +} + +func (a *Application) ApplicationConfig() *config.ApplicationConfig { + return a.applicationConfig +} + +func (a *Application) TemplatesEvaluator() *templates.Evaluator { + return a.templatesEvaluator +} diff --git a/core/startup/config_file_watcher.go b/core/application/config_file_watcher.go similarity index 96% rename from core/startup/config_file_watcher.go rename to core/application/config_file_watcher.go index df72483f..46f29b10 100644 --- a/core/startup/config_file_watcher.go +++ b/core/application/config_file_watcher.go @@ -1,4 +1,4 @@ -package startup +package application import ( "encoding/json" @@ -8,8 +8,8 @@ import ( "path/filepath" "time" - "github.com/fsnotify/fsnotify" "dario.cat/mergo" + "github.com/fsnotify/fsnotify" "github.com/mudler/LocalAI/core/config" "github.com/rs/zerolog/log" ) diff --git a/core/startup/startup.go b/core/application/startup.go similarity index 62% rename from core/startup/startup.go rename to core/application/startup.go index 0eb5fa58..cd52d37a 100644 --- a/core/startup/startup.go +++ b/core/application/startup.go @@ -1,15 +1,15 @@ -package startup +package application import ( "fmt" "os" - "github.com/mudler/LocalAI/core" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/assets" + "github.com/mudler/LocalAI/pkg/library" "github.com/mudler/LocalAI/pkg/model" pkgStartup "github.com/mudler/LocalAI/pkg/startup" @@ -17,8 +17,9 @@ import ( "github.com/rs/zerolog/log" ) -func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { +func New(opts ...config.AppOption) (*Application, error) { options := config.NewApplicationConfig(opts...) 
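+	// newApplication (core/application/application.go) wires the backend
+	// config loader, the model loader and the new templates.Evaluator around
+	// a single ApplicationConfig. A minimal sketch of the call-site migration,
+	// using only names introduced by this patch:
+	//
+	//	app, err := application.New(opts...)
+	//	if err != nil {
+	//		return err
+	//	}
+	//	_ = app.BackendLoader()     // previously the `cl` return value
+	//	_ = app.ModelLoader()       // previously the `ml` return value
+	//	_ = app.ApplicationConfig() // previously the `options` return value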
+	application := newApplication(options)
 
 	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
 	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
@@ -36,28 +37,28 @@
 
 	// Make sure directories exist
 	if options.ModelPath == "" {
-		return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
+		return nil, fmt.Errorf("options.ModelPath cannot be empty")
 	}
 
 	err = os.MkdirAll(options.ModelPath, 0750)
 	if err != nil {
-		return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
+		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
 	}
 	if options.ImageDir != "" {
 		err := os.MkdirAll(options.ImageDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
+			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
 		}
 	}
 	if options.AudioDir != "" {
 		err := os.MkdirAll(options.AudioDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
+			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
 		}
 	}
 	if options.UploadDir != "" {
 		err := os.MkdirAll(options.UploadDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
+			return nil, fmt.Errorf("unable to create UploadDir: %q", err)
 		}
 	}
 
@@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 		log.Error().Err(err).Msg("error installing models")
 	}
 
-	cl := config.NewBackendConfigLoader(options.ModelPath)
-	ml := model.NewModelLoader(options.ModelPath)
-
 	configLoaderOpts := options.ToConfigLoaderOptions()
 
-	if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
+	if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
 		log.Error().Err(err).Msg("error loading config files")
 	}
 
 	if options.ConfigFile != "" {
-		if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
+		if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
 			log.Error().Err(err).Msg("error loading config file")
 		}
 	}
 
-	if err := cl.Preload(options.ModelPath); err != nil {
+	if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
 		log.Error().Err(err).Msg("error downloading models")
 	}
 
 	if options.PreloadJSONModels != "" {
 		if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
 		}
 	}
 
 	if options.PreloadModelsFromPath != "" {
 		if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
 		}
 	}
 
 	if options.Debug {
-		for _, v := range cl.GetAllBackendConfigs() {
+		for _, v := range application.BackendLoader().GetAllBackendConfigs() {
 			log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
 		}
 	}
@@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 	go func() {
 		<-options.Context.Done()
 		log.Debug().Msgf("Context canceled, shutting down")
-		err := ml.StopAllGRPC()
+		err := application.ModelLoader().StopAllGRPC()
 		if err != nil {
 			log.Error().Err(err).Msg("error while stopping all grpc backends")
 		}
@@
-131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 
 	if options.WatchDog {
 		wd := model.NewWatchDog(
-			ml,
+			application.ModelLoader(),
 			options.WatchDogBusyTimeout,
 			options.WatchDogIdleTimeout,
 			options.WatchDogBusy,
 			options.WatchDogIdle)
-		ml.SetWatchDog(wd)
+		application.ModelLoader().SetWatchDog(wd)
 		go wd.Run()
 		go func() {
 			<-options.Context.Done()
@@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 
 	if options.LoadToMemory != nil {
 		for _, m := range options.LoadToMemory {
-			cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
+			cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
 				config.LoadOptionDebug(options.Debug),
 				config.LoadOptionThreads(options.Threads),
 				config.LoadOptionContextSize(options.ContextSize),
@@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 				config.ModelPath(options.ModelPath),
 			)
 			if err != nil {
-				return nil, nil, nil, err
+				return nil, err
 			}
 
 			log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
@@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 
 			o := backend.ModelOptions(*cfg, options)
 
 			var backendErr error
-			_, backendErr = ml.Load(o...)
+			_, backendErr = application.ModelLoader().Load(o...)
 			if backendErr != nil {
-				return nil, nil, nil, err
+				return nil, backendErr
 			}
 		}
 	}
@@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 	startWatcher(options)
 
 	log.Info().Msg("core/startup process completed!")
-	return cl, ml, options, nil
+	return application, nil
 }
 
 func startWatcher(options *config.ApplicationConfig) {
@@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) {
 		log.Error().Err(err).Msg("failed creating watcher")
 	}
 }
-
-// In Lieu of a proper DI framework, this function wires up the Application manually.
-// This is in core/startup rather than core/state.go to keep package references clean!
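+// The manual wiring that createApplication (below) used to do is superseded by
+// newApplication in core/application; the gallery and metrics services it
+// constructed are now created where they are consumed (services.NewGalleryService
+// and services.NewLocalAIMetricsService in core/http/app.go, later in this patch).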
-func createApplication(appConfig *config.ApplicationConfig) *core.Application { - app := &core.Application{ - ApplicationConfig: appConfig, - BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath), - ModelLoader: model.NewModelLoader(appConfig.ModelPath), - } - - var err error - - // app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - // app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - - app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.GalleryService = services.NewGalleryService(app.ApplicationConfig) - // app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) - - app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() - if err != nil { - log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.") - } - - return app -} diff --git a/core/cli/run.go b/core/cli/run.go index b2d439a0..a0e16155 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -6,12 +6,12 @@ import ( "strings" "time" + "github.com/mudler/LocalAI/core/application" cli_api "github.com/mudler/LocalAI/core/cli/api" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/p2p" - "github.com/mudler/LocalAI/core/startup" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { } if r.PreloadBackendOnly { - _, _, _, err := startup.Startup(opts...) + _, err := application.New(opts...) return err } - cl, ml, options, err := startup.Startup(opts...) + app, err := application.New(opts...) 
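+	// app carries the loaders and config that startup.Startup used to return
+	// as three separate values; http.API (below) consumes the whole handle.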
if err != nil {
 		return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
 	}
 
-	appHTTP, err := http.App(cl, ml, options)
+	appHTTP, err := http.API(app)
 	if err != nil {
 		log.Error().Err(err).Msg("error during HTTP App construction")
 		return err
diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index 0ff34769..f07ec3d3 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -206,6 +206,8 @@ type TemplateConfig struct {
 	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
 
 	Multimodal string `yaml:"multimodal"`
+
+	JinjaTemplate bool `yaml:"jinja_template"`
 }
 
 func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
diff --git a/core/config/guesser.go b/core/config/guesser.go
index b63dd051..f5627461 100644
--- a/core/config/guesser.go
+++ b/core/config/guesser.go
@@ -26,14 +26,14 @@ const (
 type settingsConfig struct {
 	StopWords      []string
 	TemplateConfig TemplateConfig
-	RepeatPenalty float64
+	RepeatPenalty  float64
 }
 
 // default settings to adopt with a given model family
 var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
 	Gemma: {
 		RepeatPenalty: 1.0,
-		StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
+		StopWords:     []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
 		TemplateConfig: TemplateConfig{
 			Chat:        "{{.Input }}\n<start_of_turn>model\n",
 			ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
@@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
 	} else {
 		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
 	}
+
+	if cfg.HasTemplate() {
+		return
+	}
+
+	// identify from well known templates first, otherwise use the raw jinja template
+	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
+	if found {
+		// try to use the jinja template
+		cfg.TemplateConfig.JinjaTemplate = true
+		cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
+	}
 }
 
 func identifyFamily(f *gguf.GGUFFile) familyType {
diff --git a/core/http/app.go b/core/http/app.go
index 2ba2c2b9..a2d8b87a 100644
--- a/core/http/app.go
+++ b/core/http/app.go
@@ -14,10 +14,9 @@ import (
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/http/routes"
 
-	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/pkg/model"
 
 	"github.com/gofiber/contrib/fiberzerolog"
 	"github.com/gofiber/fiber/v2"
@@ -49,18 +48,18 @@ var embedDirStatic embed.FS
 // @in header
 // @name Authorization
 
-func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
+func API(application *application.Application) (*fiber.App, error) {
 
 	fiberCfg := fiber.Config{
 		Views:     renderEngine(),
-		BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
+		BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
 		// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
 		DisableStartupMessage: true,
 		// Override default error handler
 	}
 
-	if !appConfig.OpaqueErrors {
+	if !application.ApplicationConfig().OpaqueErrors {
 		// Normally, return errors as JSON responses
 		fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
 			// Status code defaults to 500
@@ -86,9 +85,9 @@
 		}
 	}
 
-	app := fiber.New(fiberCfg)
+	router := fiber.New(fiberCfg)
 
-	app.Hooks().OnListen(func(listenData fiber.ListenData) error {
+	router.Hooks().OnListen(func(listenData fiber.ListenData) error {
 		scheme := "http"
 		if listenData.TLS {
 			scheme = "https"
@@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
 	// Have Fiber use zerolog like the rest of the application rather than its built-in logger
 	logger := log.Logger
-	app.Use(fiberzerolog.New(fiberzerolog.Config{
+	router.Use(fiberzerolog.New(fiberzerolog.Config{
 		Logger: &logger,
 	}))
 
 	// Default middleware config
-	if !appConfig.Debug {
-		app.Use(recover.New())
+	if !application.ApplicationConfig().Debug {
+		router.Use(recover.New())
 	}
 
-	if !appConfig.DisableMetrics {
+	if !application.ApplicationConfig().DisableMetrics {
 		metricsService, err := services.NewLocalAIMetricsService()
 		if err != nil {
 			return nil, err
 		}
 
 		if metricsService != nil {
-			app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
-			app.Hooks().OnShutdown(func() error {
+			router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
+			router.Hooks().OnShutdown(func() error {
 				return metricsService.Shutdown()
 			})
 		}
 	}
 
 	// Health Checks should always be exempt from auth, so register these first
-	routes.HealthRoutes(app)
+	routes.HealthRoutes(router)
 
-	kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
+	kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
 	if err != nil || kaConfig == nil {
 		return nil, fmt.Errorf("failed to create key auth config: %w", err)
 	}
 
 	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
-	app.Use(v2keyauth.New(*kaConfig))
+	router.Use(v2keyauth.New(*kaConfig))
 
-	if appConfig.CORS {
+	if application.ApplicationConfig().CORS {
 		var c func(ctx *fiber.Ctx) error
-		if appConfig.CORSAllowOrigins == "" {
+		if application.ApplicationConfig().CORSAllowOrigins == "" {
 			c = cors.New()
 		} else {
-			c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
+			c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
 		}
-		app.Use(c)
+		router.Use(c)
 	}
 
-	if appConfig.CSRF {
+	if application.ApplicationConfig().CSRF {
 		log.Debug().Msg("Enabling CSRF middleware.
Tokens are now required for state-modifying requests") - app.Use(csrf.New()) + router.Use(csrf.New()) } // Load config jsons - utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) - utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) + utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) - galleryService := services.NewGalleryService(appConfig) - galleryService.Start(appConfig.Context, cl) + galleryService := services.NewGalleryService(application.ApplicationConfig()) + galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader()) - routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig) - routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService) - routes.RegisterOpenAIRoutes(app, cl, ml, appConfig) - if !appConfig.DisableWebUI { - routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService) + routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()) + routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService) + routes.RegisterOpenAIRoutes(router, application) + if !application.ApplicationConfig().DisableWebUI { + routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService) } - routes.RegisterJINARoutes(app, cl, ml, appConfig) + routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()) httpFS := http.FS(embedDirStatic) - app.Use(favicon.New(favicon.Config{ + router.Use(favicon.New(favicon.Config{ URL: "/favicon.ico", FileSystem: httpFS, File: "static/favicon.ico", })) - app.Use("/static", filesystem.New(filesystem.Config{ + router.Use("/static", filesystem.New(filesystem.Config{ Root: httpFS, PathPrefix: "static", Browse: true, @@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Define a custom 404 handler // Note: keep this at the bottom! - app.Use(notFoundHandler) + router.Use(notFoundHandler) - return app, nil + return router, nil } diff --git a/core/http/app_test.go b/core/http/app_test.go index 83fb0e73..34ebacf7 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -12,15 +12,14 @@ import ( "path/filepath" "runtime" + "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/startup" "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" - "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gopkg.in/yaml.v3" @@ -252,9 +251,6 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string var modelDir string - var bcl *config.BackendConfigLoader - var ml *model.ModelLoader - var applicationConfig *config.ApplicationConfig commonOpts := []config.AppOption{ config.WithDebug(true), @@ -300,7 +296,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithGalleries(galleries), @@ -310,7 +306,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(backendAssetsDir))...) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -539,7 +535,7 @@ var _ = Describe("API test", func() { var res map[string]string err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) - Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) + Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) @@ -641,7 +637,7 @@ var _ = Describe("API test", func() { }, } - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithAudioDir(tmpdir), @@ -652,7 +648,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -772,14 +768,14 @@ var _ = Describe("API test", func() { var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), config.WithModelPath(modelPath), )...) 
Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -990,14 +986,14 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - bcl, ml, applicationConfig, err = startup.Startup( + application, err := application.New( append(commonOpts, config.WithContext(c), config.WithModelPath(modelPath), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(bcl, ml, applicationConfig) + app, err = API(application) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index b03b18bd..21e71d35 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -14,6 +14,8 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/templates" + model "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" @@ -24,7 +26,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { var id, textContentToReturn string var created int @@ -298,148 +300,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup // If we are using the tokenizer template, we don't need to process the messages // unless we are processing functions if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn { - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range input.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - - fcall := i.FunctionCall - if len(i.ToolCalls) > 0 { - fcall = i.ToolCalls - } - - // First attempt to populate content via a chat message specific template - if config.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: config.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - FunctionCall: fcall, - FunctionName: i.Name, - LastMessage: messageIndex == (len(input.Messages) - 1), - Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)), - MessageIndex: messageIndex, - } - templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing 
message with template, skipping") - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - - marshalAnyRole := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - marshalAny := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. - if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - - if i.FunctionCall != nil { - marshalAnyRole(i.FunctionCall) - } - if i.ToolCalls != nil { - marshalAnyRole(i.ToolCalls) - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - if i.FunctionCall != nil { - marshalAny(i.FunctionCall) - } - if i.ToolCalls != nil { - marshalAny(i.ToolCalls) - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - joinCharacter := "\n" - if config.TemplateConfig.JoinChatMessagesByCharacter != nil { - joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter - } - - predInput = strings.Join(mess, joinCharacter) - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Chat != "" && !shouldUseFn { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && shouldUseFn { - templateFile = config.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } + predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn) log.Debug().Msgf("Prompt (after templating): %s", predInput) - if shouldUseFn && config.Grammar != "" { + if config.Grammar != "" { log.Debug().Msgf("Grammar: %+v", config.Grammar) } } diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index e5de1b3f..04ebc847 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -16,6 +16,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" "github.com/rs/zerolog/log" 
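+	// templates.Evaluator (pkg/templates/evaluator.go, added by this patch)
+	// replaces the ModelLoader-backed template lookup: EvaluateTemplateForPrompt
+	// takes the whole BackendConfig, so the "%s.tmpl" model-file fallback that
+	// each endpoint used to re-implement can live with the evaluator instead.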
"github.com/valyala/fasthttp" ) @@ -25,7 +26,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { id := uuid.New().String() created := int(time.Now().Unix()) @@ -94,17 +95,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a c.Set("Transfer-Encoding", "chunked") } - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - if input.Stream { if len(config.PromptStrings) > 1 { return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") @@ -112,15 +102,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a predInput := config.PromptStrings[0] - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - Input: predInput, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ + Input: predInput, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) } responses := make(chan schema.OpenAIResponse) @@ -165,16 +153,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a totalTokenUsage := backend.TokenUsage{} for k, i := range config.PromptStrings { - if templateFile != "" { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) } r, tokenUsage, err := ComputeChoices( diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 12fb4035..a6d609fb 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" "github.com/rs/zerolog/log" ) @@ -21,7 +22,8 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/edits [post] -func EditEndpoint(cl 
*config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { modelFile, input, err := readRequest(c, cl, ml, appConfig, true) if err != nil { @@ -35,31 +37,18 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf log.Debug().Msgf("Parameter Config: %+v", config) - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - var result []schema.Choice totalTokenUsage := backend.TokenUsage{} for _, i := range config.InputStrings { - if templateFile != "" { - templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ - Input: i, - Instruction: input.Instruction, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } + templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) } r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index e7097741..2ea9896a 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -11,62 +11,62 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func RegisterLocalAIRoutes(app *fiber.App, +func RegisterLocalAIRoutes(router *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService) { - app.Get("/swagger/*", swagger.HandlerDefault) // default + router.Get("/swagger/*", swagger.HandlerDefault) // default // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) - app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) - app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) + router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) - app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) - app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) + router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) + router.Get("/models/galleries", 
modelGalleryEndpointService.ListModelGalleriesEndpoint()) + router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint()) + router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint()) + router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) + router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) } - app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) - app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig)) + router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) + router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig)) // Stores sl := model.NewModelLoader("") - app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) - app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) - app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) - app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) + router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) + router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) + router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) + router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) if !appConfig.DisableMetrics { - app.Get("/metrics", localai.LocalAIMetricsEndpoint()) + router.Get("/metrics", localai.LocalAIMetricsEndpoint()) } // Experimental Backend Statistics Module backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now - app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) - app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) + router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) + router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) // p2p if p2p.IsP2PEnabled() { - app.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) - app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) + router.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) + router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) } - app.Get("/version", func(c *fiber.Ctx) error { + router.Get("/version", func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) - app.Get("/system", localai.SystemInformations(ml, appConfig)) + router.Get("/system", localai.SystemInformations(ml, appConfig)) // misc - app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig)) + router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 081daf70..5ff301b6 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -2,84 +2,134 @@ package routes import ( "github.com/gofiber/fiber/v2" - "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/openai" - "github.com/mudler/LocalAI/pkg/model" ) func RegisterOpenAIRoutes(app *fiber.App, - cl *config.BackendConfigLoader, - ml *model.ModelLoader, - appConfig *config.ApplicationConfig) { + application *application.Application) { // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) - app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) + 
app.Post("/v1/chat/completions", + openai.ChatEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/chat/completions", + openai.ChatEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // edit - app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig)) - app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/v1/edits", + openai.EditEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/edits", + openai.EditEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // assistant - app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + 
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // files - app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig)) - app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig)) - app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/files", openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) - app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) - app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) - app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) - app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), 
application.ApplicationConfig())) + app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig())) + app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig())) // completion - app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) + + app.Post("/v1/engines/:model/completions", + openai.CompletionEndpoint( + application.BackendLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ), + ) // embeddings - app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // audio - app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig)) - app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) + app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) // images - app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig)) + app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())) - if appConfig.ImageDir != "" { - app.Static("/generated-images", appConfig.ImageDir) + if application.ApplicationConfig().ImageDir != "" { + app.Static("/generated-images", application.ApplicationConfig().ImageDir) } - if appConfig.AudioDir != "" { - app.Static("/generated-audio", appConfig.AudioDir) + if application.ApplicationConfig().AudioDir != "" { + app.Static("/generated-audio", application.ApplicationConfig().AudioDir) } // List models - app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", openai.ListModelsEndpoint(cl, ml)) + app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader())) + app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader())) } diff --git a/go.mod b/go.mod index 3bc625ac..e9bcf3ec 100644 --- a/go.mod +++ b/go.mod @@ 
-76,6 +76,7 @@ require ( cloud.google.com/go/auth v0.4.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -84,8 +85,12 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect github.com/pion/datachannel v1.5.8 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/ice/v2 v2.3.34 // indirect diff --git a/go.sum b/go.sum index 11b87fa9..f1628f7a 100644 --- a/go.sum +++ b/go.sum @@ -140,6 +140,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo= github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= @@ -268,6 +270,7 @@ github.com/google/go-containerregistry v0.19.2 h1:TannFKE1QSajsP6hPWb5oJNgKe1IKj github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -353,6 +356,8 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jtolds/gls v4.20.0+incompatible 
h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= @@ -474,8 +479,12 @@ github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5 github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo= github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= @@ -519,6 +528,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja/v2 v2.3.2 h1:UgLFfqi7L9XfX0PEcE4eUpvGojVQL5KhBfJJaBp7ZxY= +github.com/nikolalohinski/gonja/v2 v2.3.2/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg= github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY= diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b32e3745..d62f52b2 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/mudler/LocalAI/pkg/templates" - "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" @@ -23,7 +21,6 @@ type ModelLoader struct { ModelPath string mu sync.Mutex models map[string]*Model - templates *templates.TemplateCache wd *WatchDog } @@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader { nml := &ModelLoader{ ModelPath: modelPath, models: make(map[string]*Model), - templates: templates.NewTemplateCache(modelPath), } return nml diff --git a/pkg/model/template.go b/pkg/model/template.go deleted file mode 100644 index 3dc850cf..00000000 --- a/pkg/model/template.go +++ /dev/null @@ -1,52 +0,0 @@ -package model - -import ( - "fmt" - - "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/templates" -) - -// Rather than pass an interface{} to the prompt template: -// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file -// Please note: Not all of these are 
populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values. -type PromptTemplateData struct { - SystemPrompt string - SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_ - Input string - Instruction string - Functions []functions.Function - MessageIndex int -} - -type ChatMessageTemplateData struct { - SystemPrompt string - Role string - RoleName string - FunctionName string - Content string - MessageIndex int - Function bool - FunctionCall interface{} - LastMessage bool -} - -const ( - ChatPromptTemplate templates.TemplateType = iota - ChatMessageTemplate - CompletionPromptTemplate - EditPromptTemplate - FunctionsPromptTemplate -) - -func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { - // TODO: should this check be improved? - if templateType == ChatMessageTemplate { - return "", fmt.Errorf("invalid templateType: ChatMessage") - } - return ml.templates.EvaluateTemplate(templateType, templateName, in) -} - -func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { - return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) -} diff --git a/pkg/model/template_test.go b/pkg/model/template_test.go deleted file mode 100644 index 1142ed0c..00000000 --- a/pkg/model/template_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package model_test - -import ( - . "github.com/mudler/LocalAI/pkg/model" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} -{{- if .FunctionCall }} - -{{- else if eq .RoleName "tool" }} - -{{- end }} -{{- if .Content}} -{{.Content }} -{{- end }} -{{- if .FunctionCall}} -{{toJson .FunctionCall}} -{{- end }} -{{- if .FunctionCall }} - -{{- else if eq .RoleName "tool" }} - -{{- end }}<|im_end|>` - -const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|> - -{{ if .FunctionCall -}} -Function call: -{{ else if eq .RoleName "tool" -}} -Function response: -{{ end -}} -{{ if .Content -}} -{{.Content -}} -{{ else if .FunctionCall -}} -{{ toJson .FunctionCall -}} -{{ end -}} -<|eot_id|>` - -var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ - "user": { - "template": llama3, - "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "user", - RoleName: "user", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "assistant": { - "template": llama3, - "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "function_call": { - 
"template": llama3, - "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "", - FunctionCall: map[string]string{"function": "test"}, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "function_response": { - "template": llama3, - "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "tool", - RoleName: "tool", - Content: "Response from tool", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, -} - -var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ - "user": { - "template": chatML, - "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "user", - RoleName: "user", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "assistant": { - "template": chatML, - "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "function_call": { - "template": chatML, - "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "", - FunctionCall: map[string]string{"function": "test"}, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, - "function_response": { - "template": chatML, - "expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "tool", - RoleName: "tool", - Content: "Response from tool", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, - }, - }, -} - -var _ = Describe("Templates", func() { - Context("chat message ChatML", func() { - var modelLoader *ModelLoader - BeforeEach(func() { - modelLoader = NewModelLoader("") - }) - for key := range chatMLTestMatch { - foo := chatMLTestMatch[key] - It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) - Expect(err).ToNot(HaveOccurred()) - Expect(templated).To(Equal(foo["expected"]), templated) - }) - } - }) - Context("chat message llama3", func() { - var modelLoader *ModelLoader - BeforeEach(func() { - modelLoader = NewModelLoader("") - }) - for key := range llama3TestMatch { - foo := llama3TestMatch[key] - It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) - Expect(err).ToNot(HaveOccurred()) - Expect(templated).To(Equal(foo["expected"]), templated) - }) - } - }) -}) diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go index e4801946..1efce660 100644 --- a/pkg/templates/cache.go +++ 
b/pkg/templates/cache.go @@ -11,59 +11,41 @@ import ( "github.com/mudler/LocalAI/pkg/utils" "github.com/Masterminds/sprig/v3" + + "github.com/nikolalohinski/gonja/v2" + "github.com/nikolalohinski/gonja/v2/exec" ) // Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? // Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go type TemplateType int -type TemplateCache struct { - mu sync.Mutex - templatesPath string - templates map[TemplateType]map[string]*template.Template +type templateCache struct { + mu sync.Mutex + templatesPath string + templates map[TemplateType]map[string]*template.Template + jinjaTemplates map[TemplateType]map[string]*exec.Template } -func NewTemplateCache(templatesPath string) *TemplateCache { - tc := &TemplateCache{ - templatesPath: templatesPath, - templates: make(map[TemplateType]map[string]*template.Template), +func newTemplateCache(templatesPath string) *templateCache { + tc := &templateCache{ + templatesPath: templatesPath, + templates: make(map[TemplateType]map[string]*template.Template), + jinjaTemplates: make(map[TemplateType]map[string]*exec.Template), } return tc } -func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) { +func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) { if _, ok := tc.templates[tt]; !ok { tc.templates[tt] = make(map[string]*template.Template) } } -func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) { - tc.mu.Lock() - defer tc.mu.Unlock() - - tc.initializeTemplateMapKey(templateType) - m, ok := tc.templates[templateType][templateName] - if !ok { - // return "", fmt.Errorf("template not loaded: %s", templateName) - loadErr := tc.loadTemplateIfExists(templateType, templateName) - if loadErr != nil { - return "", loadErr - } - m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked - } - if m == nil { - return "", fmt.Errorf("failed loading a template for %s", templateName) - } - - var buf bytes.Buffer - - if err := m.Execute(&buf, in); err != nil { - return "", err - } - return buf.String(), nil +func (tc *templateCache) existsInModelPath(s string) bool { + return utils.ExistsInPath(tc.templatesPath, s) } - -func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error { +func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error { // Check if the template was already loaded if _, ok := tc.templates[templateType][templateName]; ok { @@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat return fmt.Errorf("template file outside path: %s", file) } + // can either be a file in the system or a string with the template + if tc.existsInModelPath(modelTemplateFile) { + d, err := os.ReadFile(file) + if err != nil { + return err + } + dat = string(d) + } else { + dat = templateName + } + + // Parse the template + tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) + if err != nil { + return err + } + tc.templates[templateType][templateName] = tmpl + + return nil +} + +func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) { + if _, ok := tc.jinjaTemplates[tt]; !ok { + tc.jinjaTemplates[tt] = make(map[string]*exec.Template) + } +} + +func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, 
templateName string) error { + // Check if the template was already loaded + if _, ok := tc.jinjaTemplates[templateType][templateName]; ok { + return nil + } + + // Check if the model path exists + // skip any error here - we run anyway if a template does not exist + modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) + + dat := "" + file := filepath.Join(tc.templatesPath, modelTemplateFile) + + // Security check + if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil { + return fmt.Errorf("template file outside path: %s", file) + } + // can either be a file in the system or a string with the template if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) { d, err := os.ReadFile(file) @@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat dat = templateName } - // Parse the template - tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) + tmpl, err := gonja.FromString(dat) if err != nil { return err } - tc.templates[templateType][templateName] = tmpl + tc.jinjaTemplates[templateType][templateName] = tmpl return nil } + +func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) { + tc.mu.Lock() + defer tc.mu.Unlock() + + tc.initializeJinjaTemplateMapKey(templateType) + m, ok := tc.jinjaTemplates[templateType][templateNameOrContent] + if !ok { + // return "", fmt.Errorf("template not loaded: %s", templateName) + loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent) + if loadErr != nil { + return "", loadErr + } + m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked + } + if m == nil { + return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent) + } + + var buf bytes.Buffer + + data := exec.NewContext(in) + + if err := m.Execute(&buf, data); err != nil { + return "", err + } + return buf.String(), nil +} + +func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) { + tc.mu.Lock() + defer tc.mu.Unlock() + + tc.initializeTemplateMapKey(templateType) + m, ok := tc.templates[templateType][templateNameOrContent] + if !ok { + // return "", fmt.Errorf("template not loaded: %s", templateName) + loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent) + if loadErr != nil { + return "", loadErr + } + m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked + } + if m == nil { + return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent) + } + + var buf bytes.Buffer + + if err := m.Execute(&buf, in); err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/pkg/templates/cache_test.go b/pkg/templates/cache_test.go deleted file mode 100644 index 8bb50766..00000000 --- a/pkg/templates/cache_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package templates_test - -import ( - "os" - "path/filepath" - - "github.com/mudler/LocalAI/pkg/templates" // Update with your module path - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("TemplateCache", func() { - var ( - templateCache *templates.TemplateCache - tempDir string - ) - - BeforeEach(func() { - var err error - tempDir, err = os.MkdirTemp("", "templates") - Expect(err).NotTo(HaveOccurred()) - - // Writing example template files - err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600) - Expect(err).NotTo(HaveOccurred()) - err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600) - Expect(err).NotTo(HaveOccurred()) - - templateCache = templates.NewTemplateCache(tempDir) - }) - - AfterEach(func() { - os.RemoveAll(tempDir) // Clean up - }) - - Describe("EvaluateTemplate", func() { - Context("when template is loaded successfully", func() { - It("should evaluate the template correctly", func() { - result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) - Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("Hello, Gopher!")) - }) - }) - - Context("when template isn't a file", func() { - It("should parse from string", func() { - result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"}) - Expect(err).ToNot(HaveOccurred()) - Expect(result).To(Equal("Gopher")) - }) - }) - - Context("when template is empty", func() { - It("should return an empty string", func() { - result, err := templateCache.EvaluateTemplate(1, "empty", nil) - Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("")) - }) - }) - }) - - Describe("concurrency", func() { - It("should handle multiple concurrent accesses", func(done Done) { - go func() { - _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) - }() - go func() { - _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) - }() - close(done) - }, 0.1) // timeout in seconds - }) -}) diff --git a/pkg/templates/evaluator.go b/pkg/templates/evaluator.go new file mode 100644 index 00000000..aedf7b41 --- /dev/null +++ b/pkg/templates/evaluator.go @@ -0,0 +1,295 @@ +package templates + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/rs/zerolog/log" +) + +// Rather than pass an interface{} to the prompt template: +// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file +// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values. 
+type PromptTemplateData struct { + SystemPrompt string + SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_ + Input string + Instruction string + Functions []functions.Function + MessageIndex int +} + +type ChatMessageTemplateData struct { + SystemPrompt string + Role string + RoleName string + FunctionName string + Content string + MessageIndex int + Function bool + FunctionCall interface{} + LastMessage bool +} + +const ( + ChatPromptTemplate TemplateType = iota + ChatMessageTemplate + CompletionPromptTemplate + EditPromptTemplate + FunctionsPromptTemplate +) + +type Evaluator struct { + cache *templateCache +} + +func NewEvaluator(modelPath string) *Evaluator { + return &Evaluator{ + cache: newTemplateCache(modelPath), + } +} + +func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) { + template := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + template = config.Model + } + + switch templateType { + case CompletionPromptTemplate: + if config.TemplateConfig.Completion != "" { + template = config.TemplateConfig.Completion + } + case EditPromptTemplate: + if config.TemplateConfig.Edit != "" { + template = config.TemplateConfig.Edit + } + case ChatPromptTemplate: + if config.TemplateConfig.Chat != "" { + template = config.TemplateConfig.Chat + } + case FunctionsPromptTemplate: + if config.TemplateConfig.Functions != "" { + template = config.TemplateConfig.Functions + } + } + + if template == "" { + return in.Input, nil + } + + if config.TemplateConfig.JinjaTemplate { + return e.evaluateJinjaTemplateForPrompt(templateType, template, in) + } + + return e.cache.evaluateTemplate(templateType, template, in) +} + +func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { + return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData) +} + +func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) { + + conversation := make(map[string]interface{}) + messages := make([]map[string]interface{}, 0, len(messageData)) // start empty: entries are appended below, so pre-sizing to len(messageData) would leave nil leading elements + + // convert from ChatMessageTemplateData to what the jinja template expects + + for _, message := range messageData { + // TODO: this seems to cover minimal text templates. 
It can be expanded to cover more complex interactions + var data []byte + data, _ = json.Marshal(message.FunctionCall) + messages = append(messages, map[string]interface{}{ + "role": message.RoleName, + "content": message.Content, + "tool_call": string(data), + }) + } + + conversation["messages"] = messages + + // if tools are detected, add these + if len(funcs) > 0 { + conversation["tools"] = funcs + } + + return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) +} + +func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { + + conversation := make(map[string]interface{}) + + conversation["system_prompt"] = in.SystemPrompt + conversation["content"] = in.Input + + return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation) +} + +func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string { + + if config.TemplateConfig.JinjaTemplate { + var messageData []ChatMessageTemplateData + for messageIndex, i := range messages { + fcall := i.FunctionCall + if len(i.ToolCalls) > 0 { + fcall = i.ToolCalls + } + messageData = append(messageData, ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: config.Roles[i.Role], + RoleName: i.Role, + Content: i.StringContent, + FunctionCall: fcall, + FunctionName: i.Name, + LastMessage: messageIndex == (len(messages) - 1), + Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), + MessageIndex: messageIndex, + }) + } + + templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs) + if err == nil { + return templatedInput + } + } + + var predInput string + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range messages { + var content string + role := i.Role + + // if this is a function call, we may want to customize the role so we can better convey that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role passed in the request + if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + + fcall := i.FunctionCall + if len(i.ToolCalls) > 0 { + fcall = i.ToolCalls + } + + // First attempt to populate content via a chat message specific template + if config.TemplateConfig.ChatMessage != "" { + chatMessageData := ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + FunctionCall: fcall, + FunctionName: i.Name, + LastMessage: messageIndex == (len(messages) - 1), + Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), + MessageIndex: messageIndex, + } + templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. 
Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + + marshalAnyRole := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + marshalAny := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + + if i.FunctionCall != nil { + marshalAnyRole(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAnyRole(i.ToolCalls) + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + if i.FunctionCall != nil { + marshalAny(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAny(i.ToolCalls) + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + joinCharacter := "\n" + if config.TemplateConfig.JoinChatMessagesByCharacter != nil { + joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter + } + + predInput = strings.Join(mess, joinCharacter) + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + promptTemplate := ChatPromptTemplate + + if config.TemplateConfig.Functions != "" && shouldUseFn { + promptTemplate = FunctionsPromptTemplate + } + + templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + + return predInput +} diff --git a/pkg/templates/evaluator_test.go b/pkg/templates/evaluator_test.go new file mode 100644 index 00000000..b58dd40b --- /dev/null +++ b/pkg/templates/evaluator_test.go @@ -0,0 +1,253 @@ +package templates_test + +import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/functions" + . "github.com/mudler/LocalAI/pkg/templates" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|> + +' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|> + +' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}` + +const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }} +{{- if .Content}} +{{.Content }} +{{- end }} +{{- if .FunctionCall}} +{{toJson .FunctionCall}} +{{- end }} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }}<|im_end|>` + +const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|> + +{{ if .FunctionCall -}} +Function call: +{{ else if eq .RoleName "tool" -}} +Function response: +{{ end -}} +{{ if .Content -}} +{{.Content -}} +{{ else if .FunctionCall -}} +{{ toJson .FunctionCall -}} +{{ end -}} +<|eot_id|>` + +var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "user", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + }, + "assistant": { + "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, + }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + "shouldUseFn": false, + }, + "function_call": { + + "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, + }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + FunctionCall: map[string]string{"function": "test"}, + }, + }, + "shouldUseFn": false, + }, + "function_response": { + "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, + }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "tool", + StringContent: "Response from tool", + }, + }, + "shouldUseFn": false, + }, +} + +var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "expected": "<|im_start|>user\nA long time ago in a galaxy far, far 
away...<|im_end|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "user", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + }, + "assistant": { + "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + "shouldUseFn": false, + }, + "function_call": { + "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{ + { + Name: "test", + Description: "test", + Parameters: nil, + }, + }, + "shouldUseFn": true, + "messages": []schema.Message{ + { + Role: "assistant", + FunctionCall: map[string]string{"function": "test"}, + }, + }, + }, + "function_response": { + "expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "tool", + StringContent: "Response from tool", + }, + }, + }, +} + +var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: toolCallJinja, + JinjaTemplate: true, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "user", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + }, +} +var _ = Describe("Templates", func() { + Context("chat message ChatML", func() { + var evaluator *Evaluator + BeforeEach(func() { + evaluator = NewEvaluator("") + }) + for key := range chatMLTestMatch { + foo := chatMLTestMatch[key] + It("renders correctly `"+key+"`", func() { + templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) + Context("chat message llama3", func() { + var evaluator *Evaluator + BeforeEach(func() { + evaluator = NewEvaluator("") + }) + for key := range llama3TestMatch { + foo := llama3TestMatch[key] + It("renders correctly `"+key+"`", func() { + templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) + Context("chat message jinja", func() { + var evaluator *Evaluator + BeforeEach(func() { + evaluator = NewEvaluator("") + }) + for key := range jinjaTest { + foo := jinjaTest[key] + It("renders correctly `"+key+"`", func() { + templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), 
foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) +})