feat(template): read jinja templates from gguf files (#4332)

* Read jinja templates as fallback
* Move templating out of model loader
* Test TemplateMessages
* Set role and content from transformers
* Tests: be more flexible
* More jinja
* Small refactoring and adaptations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

parent f5e1527a5a
commit cea5a0ea42
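In short: when a model config ends up with no prompt template at all, LocalAI now falls back to the Jinja chat template that most GGUF checkpoints ship under the tokenizer.chat_template metadata key, and the prompt-templating machinery moves out of the model loader into a new templates.Evaluator owned by a new Application container. A minimal standalone sketch of the metadata lookup the commit relies on (the ParseGGUFFile entry point and module path are assumptions based on the gguf-parser-go library LocalAI uses; the Get/ValueString calls appear verbatim in the diff below):

    package main

    import (
        "fmt"

        gguf "github.com/thxcode/gguf-parser-go"
    )

    func main() {
        // "model.gguf" is a placeholder path.
        f, err := gguf.ParseGGUFFile("model.gguf")
        if err != nil {
            panic(err)
        }
        // GGUF files store the tokenizer's Jinja chat template as a plain
        // metadata key/value pair; this is what LocalAI now reads as a fallback.
        if kv, found := f.Header.MetadataKV.Get("tokenizer.chat_template"); found {
            fmt.Println(kv.ValueString())
        }
    }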
@@ -1,38 +0,0 @@
-package core
-
-import (
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/pkg/model"
-)
-
-// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy
-// Perhaps a proper DI system is worth it in the future, but for now keep things simple.
-type Application struct {
-
-	// Application-Level Config
-	ApplicationConfig *config.ApplicationConfig
-	// ApplicationState *ApplicationState
-
-	// Core Low-Level Services
-	BackendConfigLoader *config.BackendConfigLoader
-	ModelLoader         *model.ModelLoader
-
-	// Backend Services
-	// EmbeddingsBackendService *backend.EmbeddingsBackendService
-	// ImageGenerationBackendService *backend.ImageGenerationBackendService
-	// LLMBackendService *backend.LLMBackendService
-	// TranscriptionBackendService *backend.TranscriptionBackendService
-	// TextToSpeechBackendService *backend.TextToSpeechBackendService
-
-	// LocalAI System Services
-	BackendMonitorService *services.BackendMonitorService
-	GalleryService        *services.GalleryService
-	LocalAIMetricsService *services.LocalAIMetricsService
-	// OpenAIService *services.OpenAIService
-}
-
-// TODO [NEXT PR?]: Break up ApplicationConfig.
-// Migrate over stuff that is not set via config at all - especially runtime stuff
-type ApplicationState struct {
-}
core/application/application.go (new file, 39 lines)
@@ -0,0 +1,39 @@
+package application
+
+import (
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
+)
+
+type Application struct {
+	backendLoader      *config.BackendConfigLoader
+	modelLoader        *model.ModelLoader
+	applicationConfig  *config.ApplicationConfig
+	templatesEvaluator *templates.Evaluator
+}
+
+func newApplication(appConfig *config.ApplicationConfig) *Application {
+	return &Application{
+		backendLoader:      config.NewBackendConfigLoader(appConfig.ModelPath),
+		modelLoader:        model.NewModelLoader(appConfig.ModelPath),
+		applicationConfig:  appConfig,
+		templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
+	}
+}
+
+func (a *Application) BackendLoader() *config.BackendConfigLoader {
+	return a.backendLoader
+}
+
+func (a *Application) ModelLoader() *model.ModelLoader {
+	return a.modelLoader
+}
+
+func (a *Application) ApplicationConfig() *config.ApplicationConfig {
+	return a.applicationConfig
+}
+
+func (a *Application) TemplatesEvaluator() *templates.Evaluator {
+	return a.templatesEvaluator
+}
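The new container keeps its fields unexported and exposes them through accessors, so every consumer goes through the same construction path instead of receiving loose loader/config values. A minimal sketch of the caller side (hypothetical fragment, but it mirrors the RunCMD changes later in this diff):

    app, err := application.New(opts...) // replaces the old startup.Startup(opts...)
    if err != nil {
        return err
    }
    // The three values Startup used to return are now reached via accessors:
    cfg := app.ApplicationConfig()
    loader := app.BackendLoader()
    models := app.ModelLoader()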
@@ -1,4 +1,4 @@
-package startup
+package application

 import (
 	"encoding/json"
@@ -8,8 +8,8 @@ import (
 	"path/filepath"
 	"time"

-	"github.com/fsnotify/fsnotify"
 	"dario.cat/mergo"
+	"github.com/fsnotify/fsnotify"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/rs/zerolog/log"
 )
@@ -1,15 +1,15 @@
-package startup
+package application

 import (
 	"fmt"
 	"os"

-	"github.com/mudler/LocalAI/core"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/assets"
-
 	"github.com/mudler/LocalAI/pkg/library"
 	"github.com/mudler/LocalAI/pkg/model"
 	pkgStartup "github.com/mudler/LocalAI/pkg/startup"
@@ -17,8 +17,9 @@ import (
 	"github.com/rs/zerolog/log"
 )

-func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
+func New(opts ...config.AppOption) (*Application, error) {
 	options := config.NewApplicationConfig(opts...)
+	application := newApplication(options)

 	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
 	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
@@ -36,28 +37,28 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

 	// Make sure directories exists
 	if options.ModelPath == "" {
-		return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
+		return nil, fmt.Errorf("options.ModelPath cannot be empty")
 	}
 	err = os.MkdirAll(options.ModelPath, 0750)
 	if err != nil {
-		return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
+		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
 	}
 	if options.ImageDir != "" {
 		err := os.MkdirAll(options.ImageDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
+			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
 		}
 	}
 	if options.AudioDir != "" {
 		err := os.MkdirAll(options.AudioDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
+			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
 		}
 	}
 	if options.UploadDir != "" {
 		err := os.MkdirAll(options.UploadDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
+			return nil, fmt.Errorf("unable to create UploadDir: %q", err)
 		}
 	}

@@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 		log.Error().Err(err).Msg("error installing models")
 	}

-	cl := config.NewBackendConfigLoader(options.ModelPath)
-	ml := model.NewModelLoader(options.ModelPath)
-
 	configLoaderOpts := options.ToConfigLoaderOptions()

-	if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
+	if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
 		log.Error().Err(err).Msg("error loading config files")
 	}

 	if options.ConfigFile != "" {
-		if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
+		if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
 			log.Error().Err(err).Msg("error loading config file")
 		}
 	}

-	if err := cl.Preload(options.ModelPath); err != nil {
+	if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
 		log.Error().Err(err).Msg("error downloading models")
 	}

 	if options.PreloadJSONModels != "" {
 		if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
 		}
 	}

 	if options.PreloadModelsFromPath != "" {
 		if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
 		}
 	}

 	if options.Debug {
-		for _, v := range cl.GetAllBackendConfigs() {
+		for _, v := range application.BackendLoader().GetAllBackendConfigs() {
 			log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
 		}
 	}
@@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 	go func() {
 		<-options.Context.Done()
 		log.Debug().Msgf("Context canceled, shutting down")
-		err := ml.StopAllGRPC()
+		err := application.ModelLoader().StopAllGRPC()
 		if err != nil {
 			log.Error().Err(err).Msg("error while stopping all grpc backends")
 		}
@@ -131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

 	if options.WatchDog {
 		wd := model.NewWatchDog(
-			ml,
+			application.ModelLoader(),
 			options.WatchDogBusyTimeout,
 			options.WatchDogIdleTimeout,
 			options.WatchDogBusy,
 			options.WatchDogIdle)
-		ml.SetWatchDog(wd)
+		application.ModelLoader().SetWatchDog(wd)
 		go wd.Run()
 		go func() {
 			<-options.Context.Done()
@@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

 	if options.LoadToMemory != nil {
 		for _, m := range options.LoadToMemory {
-			cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
+			cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
 				config.LoadOptionDebug(options.Debug),
 				config.LoadOptionThreads(options.Threads),
 				config.LoadOptionContextSize(options.ContextSize),
@@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 				config.ModelPath(options.ModelPath),
 			)
 			if err != nil {
-				return nil, nil, nil, err
+				return nil, err
 			}

 			log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
@@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 			o := backend.ModelOptions(*cfg, options)

 			var backendErr error
-			_, backendErr = ml.Load(o...)
+			_, backendErr = application.ModelLoader().Load(o...)
 			if backendErr != nil {
-				return nil, nil, nil, err
+				return nil, err
 			}
 		}
 	}
@@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 	startWatcher(options)

 	log.Info().Msg("core/startup process completed!")
-	return cl, ml, options, nil
+	return application, nil
 }

 func startWatcher(options *config.ApplicationConfig) {
@@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) {
 		log.Error().Err(err).Msg("failed creating watcher")
 	}
 }
-
-// In Lieu of a proper DI framework, this function wires up the Application manually.
-// This is in core/startup rather than core/state.go to keep package references clean!
-func createApplication(appConfig *config.ApplicationConfig) *core.Application {
-	app := &core.Application{
-		ApplicationConfig:   appConfig,
-		BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
-		ModelLoader:         model.NewModelLoader(appConfig.ModelPath),
-	}
-
-	var err error
-
-	// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-	// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-	// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-	// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-	// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-
-	app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
-	app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
-	// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
-
-	app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
-	if err != nil {
-		log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.")
-	}
-
-	return app
-}
@@ -6,12 +6,12 @@ import (
 	"strings"
 	"time"

+	"github.com/mudler/LocalAI/core/application"
 	cli_api "github.com/mudler/LocalAI/core/cli/api"
 	cliContext "github.com/mudler/LocalAI/core/cli/context"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http"
 	"github.com/mudler/LocalAI/core/p2p"
-	"github.com/mudler/LocalAI/core/startup"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
 )
@@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 	}

 	if r.PreloadBackendOnly {
-		_, _, _, err := startup.Startup(opts...)
+		_, err := application.New(opts...)
 		return err
 	}

-	cl, ml, options, err := startup.Startup(opts...)
+	app, err := application.New(opts...)
 	if err != nil {
 		return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
 	}

-	appHTTP, err := http.App(cl, ml, options)
+	appHTTP, err := http.API(app)
 	if err != nil {
 		log.Error().Err(err).Msg("error during HTTP App construction")
 		return err
@@ -206,6 +206,8 @@ type TemplateConfig struct {
 	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`

 	Multimodal string `yaml:"multimodal"`
+
+	JinjaTemplate bool `yaml:"jinja_template"`
 }

 func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
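The new field also lets a model config opt into Jinja rendering by hand. A hypothetical YAML sketch, assuming TemplateConfig is serialized under the model config's usual template: key (the chat_message value is a placeholder; in practice the commit fills both fields automatically from the GGUF file):

    name: my-model
    parameters:
      model: my-model.gguf
    template:
      jinja_template: true            # render chat_message with the Jinja engine
      chat_message: "{{ messages }}"  # placeholder; normally copied from tokenizer.chat_template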
@@ -26,14 +26,14 @@ const (
 type settingsConfig struct {
 	StopWords      []string
 	TemplateConfig TemplateConfig
 	RepeatPenalty  float64
 }

 // default settings to adopt with a given model family
 var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
 	Gemma: {
 		RepeatPenalty: 1.0,
 		StopWords:     []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
 		TemplateConfig: TemplateConfig{
 			Chat:        "{{.Input }}\n<start_of_turn>model\n",
 			ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
@@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
 	} else {
 		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
 	}
+
+	if cfg.HasTemplate() {
+		return
+	}
+
+	// identify from well known templates first, otherwise use the raw jinja template
+	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
+	if found {
+		// try to use the jinja template
+		cfg.TemplateConfig.JinjaTemplate = true
+		cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
+	}
 }

 func identifyFamily(f *gguf.GGUFFile) familyType {
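Note the ordering this hunk establishes: guessDefaultsFromFile first tries the well-known family templates above, and only if the config still has no template (cfg.HasTemplate()) does it fall back to the raw Jinja template stored under tokenizer.chat_template, flipping JinjaTemplate on so the evaluator knows to render it with the Jinja engine rather than the Go-template path used otherwise.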
@@ -14,10 +14,9 @@ import (
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/http/routes"

-	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/pkg/model"

 	"github.com/gofiber/contrib/fiberzerolog"
 	"github.com/gofiber/fiber/v2"
@@ -49,18 +48,18 @@ var embedDirStatic embed.FS
 // @in header
 // @name Authorization

-func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
+func API(application *application.Application) (*fiber.App, error) {

 	fiberCfg := fiber.Config{
 		Views:     renderEngine(),
-		BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
+		BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
 		// We disable the Fiber startup message as it does not conform to structured logging.
 		// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
 		DisableStartupMessage: true,
 		// Override default error handler
 	}

-	if !appConfig.OpaqueErrors {
+	if !application.ApplicationConfig().OpaqueErrors {
 		// Normally, return errors as JSON responses
 		fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
 			// Status code defaults to 500
@@ -86,9 +85,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
 		}
 	}

-	app := fiber.New(fiberCfg)
+	router := fiber.New(fiberCfg)

-	app.Hooks().OnListen(func(listenData fiber.ListenData) error {
+	router.Hooks().OnListen(func(listenData fiber.ListenData) error {
 		scheme := "http"
 		if listenData.TLS {
 			scheme = "https"
@@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi

 	// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
 	logger := log.Logger
-	app.Use(fiberzerolog.New(fiberzerolog.Config{
+	router.Use(fiberzerolog.New(fiberzerolog.Config{
 		Logger: &logger,
 	}))

 	// Default middleware config

-	if !appConfig.Debug {
-		app.Use(recover.New())
+	if !application.ApplicationConfig().Debug {
+		router.Use(recover.New())
 	}

-	if !appConfig.DisableMetrics {
+	if !application.ApplicationConfig().DisableMetrics {
 		metricsService, err := services.NewLocalAIMetricsService()
 		if err != nil {
 			return nil, err
 		}

 		if metricsService != nil {
-			app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
-			app.Hooks().OnShutdown(func() error {
+			router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
+			router.Hooks().OnShutdown(func() error {
 				return metricsService.Shutdown()
 			})
 		}

 	}
 	// Health Checks should always be exempt from auth, so register these first
-	routes.HealthRoutes(app)
+	routes.HealthRoutes(router)

-	kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
+	kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
 	if err != nil || kaConfig == nil {
 		return nil, fmt.Errorf("failed to create key auth config: %w", err)
 	}

 	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
-	app.Use(v2keyauth.New(*kaConfig))
+	router.Use(v2keyauth.New(*kaConfig))

-	if appConfig.CORS {
+	if application.ApplicationConfig().CORS {
 		var c func(ctx *fiber.Ctx) error
-		if appConfig.CORSAllowOrigins == "" {
+		if application.ApplicationConfig().CORSAllowOrigins == "" {
 			c = cors.New()
 		} else {
-			c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
+			c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
 		}

-		app.Use(c)
+		router.Use(c)
 	}

-	if appConfig.CSRF {
+	if application.ApplicationConfig().CSRF {
 		log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
-		app.Use(csrf.New())
+		router.Use(csrf.New())
 	}

 	// Load config jsons
-	utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
-	utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
-	utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
+	utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
+	utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
+	utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)

-	galleryService := services.NewGalleryService(appConfig)
-	galleryService.Start(appConfig.Context, cl)
+	galleryService := services.NewGalleryService(application.ApplicationConfig())
+	galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())

-	routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
-	routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
-	routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
-	if !appConfig.DisableWebUI {
-		routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
+	routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+	routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
+	routes.RegisterOpenAIRoutes(router, application)
+	if !application.ApplicationConfig().DisableWebUI {
+		routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
 	}
-	routes.RegisterJINARoutes(app, cl, ml, appConfig)
+	routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())

 	httpFS := http.FS(embedDirStatic)

-	app.Use(favicon.New(favicon.Config{
+	router.Use(favicon.New(favicon.Config{
 		URL:        "/favicon.ico",
 		FileSystem: httpFS,
 		File:       "static/favicon.ico",
 	}))

-	app.Use("/static", filesystem.New(filesystem.Config{
+	router.Use("/static", filesystem.New(filesystem.Config{
 		Root:       httpFS,
 		PathPrefix: "static",
 		Browse:     true,
@@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi

 	// Define a custom 404 handler
 	// Note: keep this at the bottom!
-	app.Use(notFoundHandler)
+	router.Use(notFoundHandler)

-	return app, nil
+	return router, nil
 }
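The local fiber.App is also renamed from app to router throughout, which keeps it visually distinct from the new *application.Application parameter now in scope.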
@@ -12,15 +12,14 @@ import (
 	"path/filepath"
 	"runtime"

+	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/config"
 	. "github.com/mudler/LocalAI/core/http"
 	"github.com/mudler/LocalAI/core/schema"
-	"github.com/mudler/LocalAI/core/startup"

 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/pkg/downloader"
-	"github.com/mudler/LocalAI/pkg/model"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"gopkg.in/yaml.v3"
@@ -252,9 +251,6 @@ var _ = Describe("API test", func() {
 	var cancel context.CancelFunc
 	var tmpdir string
 	var modelDir string
-	var bcl *config.BackendConfigLoader
-	var ml *model.ModelLoader
-	var applicationConfig *config.ApplicationConfig

 	commonOpts := []config.AppOption{
 		config.WithDebug(true),
@@ -300,7 +296,7 @@ var _ = Describe("API test", func() {
 		},
 	}

-	bcl, ml, applicationConfig, err = startup.Startup(
+	application, err := application.New(
 		append(commonOpts,
 			config.WithContext(c),
 			config.WithGalleries(galleries),
@@ -310,7 +306,7 @@ var _ = Describe("API test", func() {
 			config.WithBackendAssetsOutput(backendAssetsDir))...)
 	Expect(err).ToNot(HaveOccurred())

-	app, err = App(bcl, ml, applicationConfig)
+	app, err = API(application)
 	Expect(err).ToNot(HaveOccurred())

 	go app.Listen("127.0.0.1:9090")
@@ -539,7 +535,7 @@ var _ = Describe("API test", func() {
 		var res map[string]string
 		err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
 		Expect(err).ToNot(HaveOccurred())
-		Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
+		Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
 		Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
 		Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
@@ -641,7 +637,7 @@ var _ = Describe("API test", func() {
 		},
 	}

-	bcl, ml, applicationConfig, err = startup.Startup(
+	application, err := application.New(
 		append(commonOpts,
 			config.WithContext(c),
 			config.WithAudioDir(tmpdir),
@@ -652,7 +648,7 @@ var _ = Describe("API test", func() {
 			config.WithBackendAssetsOutput(tmpdir))...,
 	)
 	Expect(err).ToNot(HaveOccurred())
-	app, err = App(bcl, ml, applicationConfig)
+	app, err = API(application)
 	Expect(err).ToNot(HaveOccurred())

 	go app.Listen("127.0.0.1:9090")
@@ -772,14 +768,14 @@ var _ = Describe("API test", func() {

 	var err error

-	bcl, ml, applicationConfig, err = startup.Startup(
+	application, err := application.New(
 		append(commonOpts,
 			config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
 			config.WithContext(c),
 			config.WithModelPath(modelPath),
 		)...)
 	Expect(err).ToNot(HaveOccurred())
-	app, err = App(bcl, ml, applicationConfig)
+	app, err = API(application)
 	Expect(err).ToNot(HaveOccurred())
 	go app.Listen("127.0.0.1:9090")

@@ -990,14 +986,14 @@ var _ = Describe("API test", func() {
 	c, cancel = context.WithCancel(context.Background())

 	var err error
-	bcl, ml, applicationConfig, err = startup.Startup(
+	application, err := application.New(
 		append(commonOpts,
 			config.WithContext(c),
 			config.WithModelPath(modelPath),
 			config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
 	)
 	Expect(err).ToNot(HaveOccurred())
-	app, err = App(bcl, ml, applicationConfig)
+	app, err = API(application)
 	Expect(err).ToNot(HaveOccurred())

 	go app.Listen("127.0.0.1:9090")
@@ -14,6 +14,8 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
+	"github.com/mudler/LocalAI/pkg/templates"
+
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
@@ -24,7 +26,7 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/chat/completions [post]
-func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	var id, textContentToReturn string
 	var created int

@@ -298,148 +300,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		// If we are using the tokenizer template, we don't need to process the messages
 		// unless we are processing functions
 		if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
-			suppressConfigSystemPrompt := false
-			mess := []string{}
-			for messageIndex, i := range input.Messages {
-				var content string
-				role := i.Role
-
-				// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
-				// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
-				if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
-					roleFn := "assistant_function_call"
-					r := config.Roles[roleFn]
-					if r != "" {
-						role = roleFn
-					}
-				}
-				r := config.Roles[role]
-				contentExists := i.Content != nil && i.StringContent != ""
-
-				fcall := i.FunctionCall
-				if len(i.ToolCalls) > 0 {
-					fcall = i.ToolCalls
-				}
-
-				// First attempt to populate content via a chat message specific template
-				if config.TemplateConfig.ChatMessage != "" {
-					chatMessageData := model.ChatMessageTemplateData{
-						SystemPrompt: config.SystemPrompt,
-						Role:         r,
-						RoleName:     role,
-						Content:      i.StringContent,
-						FunctionCall: fcall,
-						FunctionName: i.Name,
-						LastMessage:  messageIndex == (len(input.Messages) - 1),
-						Function:     config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
-						MessageIndex: messageIndex,
-					}
-					templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
-					if err != nil {
-						log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
-					} else {
-						if templatedChatMessage == "" {
-							log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
-							continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
-						}
-						log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
-						content = templatedChatMessage
-					}
-				}
-
-				marshalAnyRole := func(f any) {
-					j, err := json.Marshal(f)
-					if err == nil {
-						if contentExists {
-							content += "\n" + fmt.Sprint(r, " ", string(j))
-						} else {
-							content = fmt.Sprint(r, " ", string(j))
-						}
-					}
-				}
-				marshalAny := func(f any) {
-					j, err := json.Marshal(f)
-					if err == nil {
-						if contentExists {
-							content += "\n" + string(j)
-						} else {
-							content = string(j)
-						}
-					}
-				}
-				// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
-				if content == "" {
-					if r != "" {
-						if contentExists {
-							content = fmt.Sprint(r, i.StringContent)
-						}
-
-						if i.FunctionCall != nil {
-							marshalAnyRole(i.FunctionCall)
-						}
-						if i.ToolCalls != nil {
-							marshalAnyRole(i.ToolCalls)
-						}
-					} else {
-						if contentExists {
-							content = fmt.Sprint(i.StringContent)
-						}
-						if i.FunctionCall != nil {
-							marshalAny(i.FunctionCall)
-						}
-						if i.ToolCalls != nil {
-							marshalAny(i.ToolCalls)
-						}
-					}
-					// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
-					if contentExists && role == "system" {
-						suppressConfigSystemPrompt = true
-					}
-				}
-
-				mess = append(mess, content)
-			}
-
-			joinCharacter := "\n"
-			if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
-				joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
-			}
-
-			predInput = strings.Join(mess, joinCharacter)
-			log.Debug().Msgf("Prompt (before templating): %s", predInput)
-
-			templateFile := ""
-
-			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-			if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-				templateFile = config.Model
-			}
-
-			if config.TemplateConfig.Chat != "" && !shouldUseFn {
-				templateFile = config.TemplateConfig.Chat
-			}
-
-			if config.TemplateConfig.Functions != "" && shouldUseFn {
-				templateFile = config.TemplateConfig.Functions
-			}
-
-			if templateFile != "" {
-				templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
-					SystemPrompt:         config.SystemPrompt,
-					SuppressSystemPrompt: suppressConfigSystemPrompt,
-					Input:                predInput,
-					Functions:            funcs,
-				})
-				if err == nil {
-					predInput = templatedInput
-					log.Debug().Msgf("Template found, input modified to: %s", predInput)
-				} else {
-					log.Debug().Msgf("Template failed loading: %s", err.Error())
-				}
-			}
+			predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)

 			log.Debug().Msgf("Prompt (after templating): %s", predInput)
-			if shouldUseFn && config.Grammar != "" {
+			if config.Grammar != "" {
 				log.Debug().Msgf("Grammar: %+v", config.Grammar)
 			}
 		}
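The hundred-plus lines of role mapping, per-message templating, and prompt-template resolution removed here now live behind a single call on the new evaluator. A minimal sketch of the resulting handler logic (names as in the diff; the surrounding handler variables are assumed):

    // funcs and shouldUseFn are computed earlier in the handler; predInput is
    // declared above. The evaluator resolves roles, per-message templates, the
    // prompt-level template, and Jinja templates when JinjaTemplate is set.
    if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
        predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
    }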
@@ -16,6 +16,7 @@ import (
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
 	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
 )
@@ -25,7 +26,7 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/completions [post]
-func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	id := uuid.New().String()
 	created := int(time.Now().Unix())

@@ -94,17 +95,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 		c.Set("Transfer-Encoding", "chunked")
 	}

-	templateFile := ""
-
-	// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-	if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-		templateFile = config.Model
-	}
-
-	if config.TemplateConfig.Completion != "" {
-		templateFile = config.TemplateConfig.Completion
-	}
-
 	if input.Stream {
 		if len(config.PromptStrings) > 1 {
 			return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
@@ -112,15 +102,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a

 		predInput := config.PromptStrings[0]

-		if templateFile != "" {
-			templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
-				Input:        predInput,
-				SystemPrompt: config.SystemPrompt,
-			})
-			if err == nil {
-				predInput = templatedInput
-				log.Debug().Msgf("Template found, input modified to: %s", predInput)
-			}
+		templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
+			Input:        predInput,
+			SystemPrompt: config.SystemPrompt,
+		})
+		if err == nil {
+			predInput = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", predInput)
 		}

 		responses := make(chan schema.OpenAIResponse)
@@ -165,16 +153,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 	totalTokenUsage := backend.TokenUsage{}

 	for k, i := range config.PromptStrings {
-		if templateFile != "" {
-			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-			templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
-				SystemPrompt: config.SystemPrompt,
-				Input:        i,
-			})
-			if err == nil {
-				i = templatedInput
-				log.Debug().Msgf("Template found, input modified to: %s", i)
-			}
+		templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
+			SystemPrompt: config.SystemPrompt,
+			Input:        i,
+		})
+		if err == nil {
+			i = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}

 		r, tokenUsage, err := ComputeChoices(
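The same simplification repeats for the edit endpoint next: handlers no longer probe the model path for a .tmpl file themselves. EvaluateTemplateForPrompt now receives the whole BackendConfig, so the evaluator resolves the right template (completion, edit, and so on) internally.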
@@ -12,6 +12,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
 	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"

 	"github.com/rs/zerolog/log"
 )
@@ -21,7 +22,8 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/edits [post]
-func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {

 	return func(c *fiber.Ctx) error {
 		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
@@ -35,31 +37,18 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf

 		log.Debug().Msgf("Parameter Config: %+v", config)

-		templateFile := ""
-
-		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-			templateFile = config.Model
-		}
-
-		if config.TemplateConfig.Edit != "" {
-			templateFile = config.TemplateConfig.Edit
-		}
-
 		var result []schema.Choice
 		totalTokenUsage := backend.TokenUsage{}

 		for _, i := range config.InputStrings {
-			if templateFile != "" {
-				templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
-					Input:        i,
-					Instruction:  input.Instruction,
-					SystemPrompt: config.SystemPrompt,
-				})
-				if err == nil {
-					i = templatedInput
-					log.Debug().Msgf("Template found, input modified to: %s", i)
-				}
+			templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
+				Input:        i,
+				Instruction:  input.Instruction,
+				SystemPrompt: config.SystemPrompt,
+			})
+			if err == nil {
+				i = templatedInput
+				log.Debug().Msgf("Template found, input modified to: %s", i)
 			}

 			r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
@@ -11,62 +11,62 @@ import (
 	"github.com/mudler/LocalAI/pkg/model"
 )

-func RegisterLocalAIRoutes(app *fiber.App,
+func RegisterLocalAIRoutes(router *fiber.App,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
 	galleryService *services.GalleryService) {

-	app.Get("/swagger/*", swagger.HandlerDefault) // default
+	router.Get("/swagger/*", swagger.HandlerDefault) // default

 	// LocalAI API endpoints
 	if !appConfig.DisableGalleryEndpoint {
 		modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
-		app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
-		app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
+		router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
+		router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())

-		app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
-		app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
-		app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
-		app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
-		app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
-		app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
+		router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
+		router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
+		router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
+		router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
+		router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
+		router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
 	}

-	app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
-	app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
+	router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
+	router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))

 	// Stores
 	sl := model.NewModelLoader("")
-	app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
-	app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
-	app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
-	app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
+	router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
+	router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
+	router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
+	router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))

 	if !appConfig.DisableMetrics {
-		app.Get("/metrics", localai.LocalAIMetricsEndpoint())
+		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
 	}

 	// Experimental Backend Statistics Module
 	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
-	app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
-	app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
+	router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
+	router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))

 	// p2p
 	if p2p.IsP2PEnabled() {
-		app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
-		app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
+		router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
+		router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
 	}

-	app.Get("/version", func(c *fiber.Ctx) error {
+	router.Get("/version", func(c *fiber.Ctx) error {
 		return c.JSON(struct {
 			Version string `json:"version"`
 		}{Version: internal.PrintableVersion()})
 	})

-	app.Get("/system", localai.SystemInformations(ml, appConfig))
+	router.Get("/system", localai.SystemInformations(ml, appConfig))

 	// misc
-	app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
+	router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))

 }
core/http/routes/openai.go
@@ -2,84 +2,134 @@ package routes
 
 import (
 	"github.com/gofiber/fiber/v2"
-	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai"
-	"github.com/mudler/LocalAI/pkg/model"
 )
 
 func RegisterOpenAIRoutes(app *fiber.App,
-	cl *config.BackendConfigLoader,
-	ml *model.ModelLoader,
-	appConfig *config.ApplicationConfig) {
+	application *application.Application) {
 	// openAI compatible API endpoint
 
 	// chat
-	app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
-	app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
+	app.Post("/v1/chat/completions",
+		openai.ChatEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
+
+	app.Post("/chat/completions",
+		openai.ChatEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
 
 	// edit
-	app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
-	app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
+	app.Post("/v1/edits",
+		openai.EditEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
+
+	app.Post("/edits",
+		openai.EditEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
 
 	// assistant
-	app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
-	app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
-	app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
-	app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
-	app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
-	app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
-	app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
-	app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
-	app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
-	app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
-	app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
-	app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
-	app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
-	app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
-	app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
-	app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
-	app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
-	app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
+	app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 
 	// files
-	app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
-	app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
-	app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
-	app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
-	app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
-	app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
-	app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
-	app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
-	app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
-	app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
+	app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
+	app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
 
 	// completion
-	app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
-	app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
-	app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
+	app.Post("/v1/completions",
+		openai.CompletionEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
+
+	app.Post("/completions",
+		openai.CompletionEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
+
+	app.Post("/v1/engines/:model/completions",
+		openai.CompletionEndpoint(
+			application.BackendLoader(),
+			application.ModelLoader(),
+			application.TemplatesEvaluator(),
+			application.ApplicationConfig(),
+		),
+	)
 
 	// embeddings
-	app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
-	app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
-	app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
+	app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 
 	// audio
-	app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
-	app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
+	app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 
 	// images
-	app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
+	app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 
-	if appConfig.ImageDir != "" {
-		app.Static("/generated-images", appConfig.ImageDir)
+	if application.ApplicationConfig().ImageDir != "" {
+		app.Static("/generated-images", application.ApplicationConfig().ImageDir)
 	}
 
-	if appConfig.AudioDir != "" {
-		app.Static("/generated-audio", appConfig.AudioDir)
+	if application.ApplicationConfig().AudioDir != "" {
+		app.Static("/generated-audio", application.ApplicationConfig().AudioDir)
 	}
 
 	// List models
-	app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
-	app.Get("/models", openai.ListModelsEndpoint(cl, ml))
+	app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
+	app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
 }
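
Note: the accessor calls above (application.BackendLoader() and friends) are thin read-only getters on *application.Application. As a rough sketch of the shape this refactor assumes — method names match the call sites in this hunk, but the receiver field names below are illustrative assumptions, not the actual definitions:

	// Sketch only: getter shapes inferred from the call sites above;
	// field names are assumed for illustration.
	func (a *Application) BackendLoader() *config.BackendConfigLoader { return a.backendLoader }

	func (a *Application) ModelLoader() *model.ModelLoader { return a.modelLoader }

	func (a *Application) ApplicationConfig() *config.ApplicationConfig { return a.applicationConfig }

	func (a *Application) TemplatesEvaluator() *templates.Evaluator { return a.templatesEvaluator }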
go.mod (5 additions)
@@ -76,6 +76,7 @@ require (
 	cloud.google.com/go/auth v0.4.1 // indirect
 	cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
 	cloud.google.com/go/compute/metadata v0.3.0 // indirect
+	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
 	github.com/fasthttp/websocket v1.5.3 // indirect
 	github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -84,8 +85,12 @@ require (
 	github.com/google/s2a-go v0.1.7 // indirect
 	github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
 	github.com/googleapis/gax-go/v2 v2.12.4 // indirect
+	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/moby/docker-image-spec v1.3.1 // indirect
+	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
+	github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
 	github.com/pion/datachannel v1.5.8 // indirect
 	github.com/pion/dtls/v2 v2.2.12 // indirect
 	github.com/pion/ice/v2 v2.3.34 // indirect
go.sum (12 additions)
@@ -140,6 +140,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
 github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
 github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
 github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
 github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
@@ -268,6 +270,7 @@ github.com/google/go-containerregistry v0.19.2 h1:TannFKE1QSajsP6hPWb5oJNgKe1IKj
 github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI=
 github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
 github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
 github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
@@ -353,6 +356,8 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
 github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
 github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
@@ -474,8 +479,12 @@ github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5
 github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
 github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
 github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
 github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
 github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
@@ -519,6 +528,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
 github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
+github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
+github.com/nikolalohinski/gonja/v2 v2.3.2 h1:UgLFfqi7L9XfX0PEcE4eUpvGojVQL5KhBfJJaBp7ZxY=
+github.com/nikolalohinski/gonja/v2 v2.3.2/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg=
 github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
 github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
 github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=
pkg/model/loader.go
@@ -9,8 +9,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/mudler/LocalAI/pkg/templates"
-
 	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/rs/zerolog/log"
@@ -23,7 +21,6 @@ type ModelLoader struct {
 	ModelPath string
 	mu        sync.Mutex
 	models    map[string]*Model
-	templates *templates.TemplateCache
 	wd        *WatchDog
 }
 
@@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
 	nml := &ModelLoader{
 		ModelPath: modelPath,
 		models:    make(map[string]*Model),
-		templates: templates.NewTemplateCache(modelPath),
 	}
 
 	return nml
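
Note: with the template cache gone from ModelLoader, prompt rendering moves to the templates.Evaluator introduced further down in this diff. A hedged before/after sketch (the variable names appConfig, backendConfig, templateName and promptData are assumed for illustration):

	// Before this commit (method removed below):
	// out, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateName, promptData)

	// After this commit (shape taken from pkg/templates/evaluator.go below):
	evaluator := templates.NewEvaluator(appConfig.ModelPath)
	out, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, backendConfig, promptData)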
pkg/model/template.go (deleted)
@@ -1,52 +0,0 @@
-package model
-
-import (
-	"fmt"
-
-	"github.com/mudler/LocalAI/pkg/functions"
-	"github.com/mudler/LocalAI/pkg/templates"
-)
-
-// Rather than pass an interface{} to the prompt template:
-// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
-// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
-type PromptTemplateData struct {
-	SystemPrompt         string
-	SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
-	Input                string
-	Instruction          string
-	Functions            []functions.Function
-	MessageIndex         int
-}
-
-type ChatMessageTemplateData struct {
-	SystemPrompt string
-	Role         string
-	RoleName     string
-	FunctionName string
-	Content      string
-	MessageIndex int
-	Function     bool
-	FunctionCall interface{}
-	LastMessage  bool
-}
-
-const (
-	ChatPromptTemplate templates.TemplateType = iota
-	ChatMessageTemplate
-	CompletionPromptTemplate
-	EditPromptTemplate
-	FunctionsPromptTemplate
-)
-
-func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
-	// TODO: should this check be improved?
-	if templateType == ChatMessageTemplate {
-		return "", fmt.Errorf("invalid templateType: ChatMessage")
-	}
-	return ml.templates.EvaluateTemplate(templateType, templateName, in)
-}
-
-func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
-	return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
-}
pkg/model/template_test.go (deleted)
@@ -1,197 +0,0 @@
-package model_test
-
-import (
-	. "github.com/mudler/LocalAI/pkg/model"
-
-	. "github.com/onsi/ginkgo/v2"
-	. "github.com/onsi/gomega"
-)
-
-const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
-{{- if .FunctionCall }}
-<tool_call>
-{{- else if eq .RoleName "tool" }}
-<tool_response>
-{{- end }}
-{{- if .Content}}
-{{.Content }}
-{{- end }}
-{{- if .FunctionCall}}
-{{toJson .FunctionCall}}
-{{- end }}
-{{- if .FunctionCall }}
-</tool_call>
-{{- else if eq .RoleName "tool" }}
-</tool_response>
-{{- end }}<|im_end|>`
-
-const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
-
-{{ if .FunctionCall -}}
-Function call:
-{{ else if eq .RoleName "tool" -}}
-Function response:
-{{ end -}}
-{{ if .Content -}}
-{{.Content -}}
-{{ else if .FunctionCall -}}
-{{ toJson .FunctionCall -}}
-{{ end -}}
-<|eot_id|>`
-
-var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
-	"user": {
-		"template": llama3,
-		"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "user",
-			RoleName:     "user",
-			Content:      "A long time ago in a galaxy far, far away...",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"assistant": {
-		"template": llama3,
-		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "assistant",
-			RoleName:     "assistant",
-			Content:      "A long time ago in a galaxy far, far away...",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"function_call": {
-		"template": llama3,
-		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "assistant",
-			RoleName:     "assistant",
-			Content:      "",
-			FunctionCall: map[string]string{"function": "test"},
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"function_response": {
-		"template": llama3,
-		"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "tool",
-			RoleName:     "tool",
-			Content:      "Response from tool",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-}
-
-var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
-	"user": {
-		"template": chatML,
-		"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "user",
-			RoleName:     "user",
-			Content:      "A long time ago in a galaxy far, far away...",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"assistant": {
-		"template": chatML,
-		"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "assistant",
-			RoleName:     "assistant",
-			Content:      "A long time ago in a galaxy far, far away...",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"function_call": {
-		"template": chatML,
-		"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "assistant",
-			RoleName:     "assistant",
-			Content:      "",
-			FunctionCall: map[string]string{"function": "test"},
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-	"function_response": {
-		"template": chatML,
-		"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
-		"data": ChatMessageTemplateData{
-			SystemPrompt: "",
-			Role:         "tool",
-			RoleName:     "tool",
-			Content:      "Response from tool",
-			FunctionCall: nil,
-			FunctionName: "",
-			LastMessage:  false,
-			Function:     false,
-			MessageIndex: 0,
-		},
-	},
-}
-
-var _ = Describe("Templates", func() {
-	Context("chat message ChatML", func() {
-		var modelLoader *ModelLoader
-		BeforeEach(func() {
-			modelLoader = NewModelLoader("")
-		})
-		for key := range chatMLTestMatch {
-			foo := chatMLTestMatch[key]
-			It("renders correctly `"+key+"`", func() {
-				templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
-				Expect(err).ToNot(HaveOccurred())
-				Expect(templated).To(Equal(foo["expected"]), templated)
-			})
-		}
-	})
-	Context("chat message llama3", func() {
-		var modelLoader *ModelLoader
-		BeforeEach(func() {
-			modelLoader = NewModelLoader("")
-		})
-		for key := range llama3TestMatch {
-			foo := llama3TestMatch[key]
-			It("renders correctly `"+key+"`", func() {
-				templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
-				Expect(err).ToNot(HaveOccurred())
-				Expect(templated).To(Equal(foo["expected"]), templated)
-			})
-		}
-	})
-})
pkg/templates/cache.go
@@ -11,59 +11,41 @@ import (
 	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/Masterminds/sprig/v3"
+
+	"github.com/nikolalohinski/gonja/v2"
+	"github.com/nikolalohinski/gonja/v2/exec"
 )
 
 // Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
 // Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
 type TemplateType int
 
-type TemplateCache struct {
+type templateCache struct {
 	mu            sync.Mutex
 	templatesPath string
 	templates     map[TemplateType]map[string]*template.Template
+	jinjaTemplates map[TemplateType]map[string]*exec.Template
 }
 
-func NewTemplateCache(templatesPath string) *TemplateCache {
-	tc := &TemplateCache{
+func newTemplateCache(templatesPath string) *templateCache {
+	tc := &templateCache{
 		templatesPath: templatesPath,
 		templates:     make(map[TemplateType]map[string]*template.Template),
+		jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
 	}
 	return tc
 }
 
-func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
+func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
 	if _, ok := tc.templates[tt]; !ok {
 		tc.templates[tt] = make(map[string]*template.Template)
 	}
 }
 
-func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
-	tc.mu.Lock()
-	defer tc.mu.Unlock()
-
-	tc.initializeTemplateMapKey(templateType)
-	m, ok := tc.templates[templateType][templateName]
-	if !ok {
-		// return "", fmt.Errorf("template not loaded: %s", templateName)
-		loadErr := tc.loadTemplateIfExists(templateType, templateName)
-		if loadErr != nil {
-			return "", loadErr
-		}
-		m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and we already checked
-	}
-	if m == nil {
-		return "", fmt.Errorf("failed loading a template for %s", templateName)
-	}
-
-	var buf bytes.Buffer
-
-	if err := m.Execute(&buf, in); err != nil {
-		return "", err
-	}
-	return buf.String(), nil
-}
+func (tc *templateCache) existsInModelPath(s string) bool {
+	return utils.ExistsInPath(tc.templatesPath, s)
+}
 
-func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
+func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
 
 	// Check if the template was already loaded
 	if _, ok := tc.templates[templateType][templateName]; ok {
@@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
 		return fmt.Errorf("template file outside path: %s", file)
 	}
 
+	// can either be a file in the system or a string with the template
+	if tc.existsInModelPath(modelTemplateFile) {
+		d, err := os.ReadFile(file)
+		if err != nil {
+			return err
+		}
+		dat = string(d)
+	} else {
+		dat = templateName
+	}
+
+	// Parse the template
+	tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
+	if err != nil {
+		return err
+	}
+	tc.templates[templateType][templateName] = tmpl
+
+	return nil
+}
+
+func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
+	if _, ok := tc.jinjaTemplates[tt]; !ok {
+		tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
+	}
+}
+
+func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
+	// Check if the template was already loaded
+	if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
+		return nil
+	}
+
+	// Check if the model path exists
+	// skip any error here - we run anyway if a template does not exist
+	modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
+
+	dat := ""
+	file := filepath.Join(tc.templatesPath, modelTemplateFile)
+
+	// Security check
+	if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
+		return fmt.Errorf("template file outside path: %s", file)
+	}
+
 	// can either be a file in the system or a string with the template
 	if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
 		d, err := os.ReadFile(file)
@@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
 		dat = templateName
 	}
 
-	// Parse the template
-	tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
+	tmpl, err := gonja.FromString(dat)
 	if err != nil {
 		return err
 	}
-	tc.templates[templateType][templateName] = tmpl
+	tc.jinjaTemplates[templateType][templateName] = tmpl
 
 	return nil
 }
+
+func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
+	tc.mu.Lock()
+	defer tc.mu.Unlock()
+
+	tc.initializeJinjaTemplateMapKey(templateType)
+	m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
+	if !ok {
+		// return "", fmt.Errorf("template not loaded: %s", templateName)
+		loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
+		if loadErr != nil {
+			return "", loadErr
+		}
+		m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
+	}
+	if m == nil {
+		return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
+	}
+
+	var buf bytes.Buffer
+
+	data := exec.NewContext(in)
+
+	if err := m.Execute(&buf, data); err != nil {
+		return "", err
+	}
+	return buf.String(), nil
+}
+
+func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
+	tc.mu.Lock()
+	defer tc.mu.Unlock()
+
+	tc.initializeTemplateMapKey(templateType)
+	m, ok := tc.templates[templateType][templateNameOrContent]
+	if !ok {
+		// return "", fmt.Errorf("template not loaded: %s", templateName)
+		loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
+		if loadErr != nil {
+			return "", loadErr
+		}
+		m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
+	}
+	if m == nil {
+		return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
+	}
+
+	var buf bytes.Buffer
+
+	if err := m.Execute(&buf, in); err != nil {
+		return "", err
+	}
+	return buf.String(), nil
+}
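
Note: the jinja path added above rests on two gonja/v2 calls — gonja.FromString to parse and Template.Execute with an exec.Context to render — exactly as used in loadJinjaTemplateIfExists and evaluateJinjaTemplate. A minimal self-contained sketch of that flow (the template string and data here are made up for illustration):

	package main

	import (
		"bytes"
		"fmt"

		"github.com/nikolalohinski/gonja/v2"
		"github.com/nikolalohinski/gonja/v2/exec"
	)

	func main() {
		// Parse a jinja-style chat template from a string, as loadJinjaTemplateIfExists does.
		tmpl, err := gonja.FromString("<|im_start|>{{ messages[0]['role'] }}\n{{ messages[0]['content'] }}<|im_end|>")
		if err != nil {
			panic(err)
		}

		// Wrap the template variables in an exec.Context, as evaluateJinjaTemplate does.
		ctx := exec.NewContext(map[string]interface{}{
			"messages": []map[string]interface{}{
				{"role": "user", "content": "Hello!"},
			},
		})

		var buf bytes.Buffer
		if err := tmpl.Execute(&buf, ctx); err != nil {
			panic(err)
		}
		fmt.Println(buf.String()) // <|im_start|>user\nHello!<|im_end|>
	}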
pkg/templates/cache_test.go (deleted)
@@ -1,73 +0,0 @@
-package templates_test
-
-import (
-	"os"
-	"path/filepath"
-
-	"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
-	. "github.com/onsi/ginkgo/v2"
-	. "github.com/onsi/gomega"
-)
-
-var _ = Describe("TemplateCache", func() {
-	var (
-		templateCache *templates.TemplateCache
-		tempDir       string
-	)
-
-	BeforeEach(func() {
-		var err error
-		tempDir, err = os.MkdirTemp("", "templates")
-		Expect(err).NotTo(HaveOccurred())
-
-		// Writing example template files
-		err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
-		Expect(err).NotTo(HaveOccurred())
-		err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
-		Expect(err).NotTo(HaveOccurred())
-
-		templateCache = templates.NewTemplateCache(tempDir)
-	})
-
-	AfterEach(func() {
-		os.RemoveAll(tempDir) // Clean up
-	})
-
-	Describe("EvaluateTemplate", func() {
-		Context("when template is loaded successfully", func() {
-			It("should evaluate the template correctly", func() {
-				result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
-				Expect(err).NotTo(HaveOccurred())
-				Expect(result).To(Equal("Hello, Gopher!"))
-			})
-		})
-
-		Context("when template isn't a file", func() {
-			It("should parse from string", func() {
-				result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
-				Expect(err).ToNot(HaveOccurred())
-				Expect(result).To(Equal("Gopher"))
-			})
-		})
-
-		Context("when template is empty", func() {
-			It("should return an empty string", func() {
-				result, err := templateCache.EvaluateTemplate(1, "empty", nil)
-				Expect(err).NotTo(HaveOccurred())
-				Expect(result).To(Equal(""))
-			})
-		})
-	})
-
-	Describe("concurrency", func() {
-		It("should handle multiple concurrent accesses", func(done Done) {
-			go func() {
-				_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
-			}()
-			go func() {
-				_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
-			}()
-			close(done)
-		}, 0.1) // timeout in seconds
-	})
-})
295
pkg/templates/evaluator.go
Normal file
295
pkg/templates/evaluator.go
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
package templates
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rather than pass an interface{} to the prompt template:
|
||||||
|
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||||
|
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||||
|
type PromptTemplateData struct {
|
||||||
|
SystemPrompt string
|
||||||
|
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||||
|
Input string
|
||||||
|
Instruction string
|
||||||
|
Functions []functions.Function
|
||||||
|
MessageIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessageTemplateData struct {
|
||||||
|
SystemPrompt string
|
||||||
|
Role string
|
||||||
|
RoleName string
|
||||||
|
FunctionName string
|
||||||
|
Content string
|
||||||
|
MessageIndex int
|
||||||
|
Function bool
|
||||||
|
FunctionCall interface{}
|
||||||
|
LastMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatPromptTemplate TemplateType = iota
|
||||||
|
ChatMessageTemplate
|
||||||
|
CompletionPromptTemplate
|
||||||
|
EditPromptTemplate
|
||||||
|
FunctionsPromptTemplate
|
||||||
|
)
|
||||||
|
|
||||||
|
type Evaluator struct {
|
||||||
|
cache *templateCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEvaluator(modelPath string) *Evaluator {
|
||||||
|
return &Evaluator{
|
||||||
|
cache: newTemplateCache(modelPath),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
|
||||||
|
template := ""
|
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
|
template = config.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
switch templateType {
|
||||||
|
case CompletionPromptTemplate:
|
||||||
|
if config.TemplateConfig.Completion != "" {
|
||||||
|
template = config.TemplateConfig.Completion
|
||||||
|
}
|
||||||
|
case EditPromptTemplate:
|
||||||
|
if config.TemplateConfig.Edit != "" {
|
||||||
|
template = config.TemplateConfig.Edit
|
||||||
|
}
|
||||||
|
case ChatPromptTemplate:
|
||||||
|
if config.TemplateConfig.Chat != "" {
|
||||||
|
template = config.TemplateConfig.Chat
|
||||||
|
}
|
||||||
|
case FunctionsPromptTemplate:
|
||||||
|
if config.TemplateConfig.Functions != "" {
|
||||||
|
template = config.TemplateConfig.Functions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if template == "" {
|
||||||
|
return in.Input, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TemplateConfig.JinjaTemplate {
|
||||||
|
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.cache.evaluateTemplate(templateType, template, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||||
|
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
|
||||||
|
|
||||||
|
conversation := make(map[string]interface{})
|
||||||
|
messages := make([]map[string]interface{}, len(messageData))
|
||||||
|
|
||||||
|
// convert from ChatMessageTemplateData to what the jinja template expects
|
||||||
|
|
||||||
|
for _, message := range messageData {
|
||||||
|
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
|
||||||
|
var data []byte
|
||||||
|
data, _ = json.Marshal(message.FunctionCall)
|
||||||
|
messages = append(messages, map[string]interface{}{
|
||||||
|
"role": message.RoleName,
|
||||||
|
"content": message.Content,
|
||||||
|
"tool_call": string(data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation["messages"] = messages
|
||||||
|
|
||||||
|
// if tools are detected, add these
|
||||||
|
if len(funcs) > 0 {
|
||||||
|
conversation["tools"] = funcs
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||||
|
|
||||||
|
conversation := make(map[string]interface{})
|
||||||
|
|
||||||
|
conversation["system_prompt"] = in.SystemPrompt
|
||||||
|
conversation["content"] = in.Input
|
||||||
|
|
||||||
|
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||||
	if config.TemplateConfig.JinjaTemplate {
		var messageData []ChatMessageTemplateData
		for messageIndex, i := range messages {
			fcall := i.FunctionCall
			if len(i.ToolCalls) > 0 {
				fcall = i.ToolCalls
			}
			messageData = append(messageData, ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         config.Roles[i.Role],
				RoleName:     i.Role,
				Content:      i.StringContent,
				FunctionCall: fcall,
				FunctionName: i.Name,
				LastMessage:  messageIndex == (len(messages) - 1),
				Function:     config.Grammar != "" && (messageIndex == (len(messages) - 1)),
				MessageIndex: messageIndex,
			})
		}

		templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
		if err == nil {
			return templatedInput
		}
	}

	var predInput string
	suppressConfigSystemPrompt := false
	mess := []string{}
	for messageIndex, i := range messages {
		var content string
		role := i.Role

		// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
		// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
		if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
			roleFn := "assistant_function_call"
			r := config.Roles[roleFn]
			if r != "" {
				role = roleFn
			}
		}
		r := config.Roles[role]
		contentExists := i.Content != nil && i.StringContent != ""

		fcall := i.FunctionCall
		if len(i.ToolCalls) > 0 {
			fcall = i.ToolCalls
		}

		// First attempt to populate content via a chat message specific template
		if config.TemplateConfig.ChatMessage != "" {
			chatMessageData := ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         r,
				RoleName:     role,
				Content:      i.StringContent,
				FunctionCall: fcall,
				FunctionName: i.Name,
				LastMessage:  messageIndex == (len(messages) - 1),
				Function:     config.Grammar != "" && (messageIndex == (len(messages) - 1)),
				MessageIndex: messageIndex,
			}
			templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
			if err != nil {
				log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
			} else {
				if templatedChatMessage == "" {
					log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
					continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
				}
				log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
				content = templatedChatMessage
			}
		}

		marshalAnyRole := func(f any) {
			j, err := json.Marshal(f)
			if err == nil {
				if contentExists {
					content += "\n" + fmt.Sprint(r, " ", string(j))
				} else {
					content = fmt.Sprint(r, " ", string(j))
				}
			}
		}
		marshalAny := func(f any) {
			j, err := json.Marshal(f)
			if err == nil {
				if contentExists {
					content += "\n" + string(j)
				} else {
					content = string(j)
				}
			}
		}
		// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
		if content == "" {
			if r != "" {
				if contentExists {
					content = fmt.Sprint(r, i.StringContent)
				}

				if i.FunctionCall != nil {
					marshalAnyRole(i.FunctionCall)
				}
				if i.ToolCalls != nil {
					marshalAnyRole(i.ToolCalls)
				}
			} else {
				if contentExists {
					content = fmt.Sprint(i.StringContent)
				}
				if i.FunctionCall != nil {
					marshalAny(i.FunctionCall)
				}
				if i.ToolCalls != nil {
					marshalAny(i.ToolCalls)
				}
			}
			// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
			if contentExists && role == "system" {
				suppressConfigSystemPrompt = true
			}
		}

		mess = append(mess, content)
	}

	joinCharacter := "\n"
	if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
		joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
	}

	predInput = strings.Join(mess, joinCharacter)
	log.Debug().Msgf("Prompt (before templating): %s", predInput)

	promptTemplate := ChatPromptTemplate

	if config.TemplateConfig.Functions != "" && shouldUseFn {
		promptTemplate = FunctionsPromptTemplate
	}

	templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
		SystemPrompt:         config.SystemPrompt,
		SuppressSystemPrompt: suppressConfigSystemPrompt,
		Input:                predInput,
		Functions:            funcs,
	})
	if err == nil {
		predInput = templatedInput
		log.Debug().Msgf("Template found, input modified to: %s", predInput)
	} else {
		log.Debug().Msgf("Template failed loading: %s", err.Error())
	}

	return predInput
}
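
For orientation, here is a minimal, hypothetical sketch of how a caller could drive this fallback chain (the model path, template string, and message content below are illustrative assumptions, not part of this change):

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/templates"
)

func main() {
	// Hypothetical model directory; NewEvaluator uses it to resolve
	// on-disk templates.
	e := templates.NewEvaluator("./models")

	cfg := &config.BackendConfig{
		TemplateConfig: config.TemplateConfig{
			// A per-message Go template. Setting JinjaTemplate: true would
			// route the same field through the Jinja branch above instead.
			ChatMessage: `{{.RoleName}}: {{.Content}}`,
		},
	}

	// No functions and shouldUseFn=false, so the plain chat path is taken.
	prompt := e.TemplateMessages([]schema.Message{
		{Role: "user", StringContent: "Hello!"},
	}, cfg, []functions.Function{}, false)

	fmt.Println(prompt)
}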
253
pkg/templates/evaluator_test.go
Normal file
@@ -0,0 +1,253 @@
package templates_test

import (
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	. "github.com/mudler/LocalAI/pkg/templates"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>

' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>

' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>

' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`

const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`

const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>

{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`
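
// The constants above cover the two template dialects the Evaluator accepts:
// toolCallJinja uses Jinja syntax (rendered when TemplateConfig.JinjaTemplate
// is set), while chatML and llama3 use Go text/template syntax over
// ChatMessageTemplateData fields such as .RoleName, .Content and .FunctionCall.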

var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
	"assistant": {
		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "assistant",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
		"shouldUseFn": false,
	},
	"function_call": {

		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:         "assistant",
				FunctionCall: map[string]string{"function": "test"},
			},
		},
		"shouldUseFn": false,
	},
	"function_response": {
		"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "tool",
				StringContent: "Response from tool",
			},
		},
		"shouldUseFn": false,
	},
}

var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
	"assistant": {
		"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "assistant",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
		"shouldUseFn": false,
	},
	"function_call": {
		"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions": []functions.Function{
			{
				Name:        "test",
				Description: "test",
				Parameters:  nil,
			},
		},
		"shouldUseFn": true,
		"messages": []schema.Message{
			{
				Role:         "assistant",
				FunctionCall: map[string]string{"function": "test"},
			},
		},
	},
	"function_response": {
		"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "tool",
				StringContent: "Response from tool",
			},
		},
	},
}

var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage:   toolCallJinja,
				JinjaTemplate: true,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
}
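
// The fixture maps hold heterogeneous values, so the specs below recover the
// concrete types with assertions; copying each entry into `foo` before the
// It() is declared gives every generated spec its own fixture instead of the
// shared loop variable.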
var _ = Describe("Templates", func() {
	Context("chat message ChatML", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range chatMLTestMatch {
			foo := chatMLTestMatch[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
	Context("chat message llama3", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range llama3TestMatch {
			foo := llama3TestMatch[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
	Context("chat message jinja", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range jinjaTest {
			foo := jinjaTest[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
})
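
These specs use Ginkgo, so the package also needs the standard suite bootstrap to run under go test; a minimal sketch follows, assuming a file name like pkg/templates/templates_suite_test.go (the name, and whether such a file already exists elsewhere in the tree, is an assumption):

package templates_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

// Wires the Ginkgo specs above into the regular `go test` runner.
func TestTemplates(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Templates test suite")
}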