mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-24 06:46:39 +00:00
feat(llama2): add template for chat messages (#782)
Co-authored-by: Aman Karmani <aman@tmm1.net> Lays some of the groundwork for LLAMA2 compatibility as well as other future models with complex prompting schemes. Started small refactoring in pkg/model/loader.go regarding template loading. Currently still a part of ModelLoader, but should be easy to add template loading for situations other than overall prompt templates and the new chat-specific per-message templates Adds support for new chat-endpoint-specific, per-message templates as an alternative to the existing Role: XYZ sprintf method. Includes a temporary prompt template as an example, since I have a few questions before we merge in the model-gallery side changes (see ) Minor debug logging changes.
This commit is contained in:
parent
5ee186b8e5
commit
c6bf67f446
@ -49,6 +49,8 @@ type Config struct {
|
|||||||
functionCallString, functionCallNameString string
|
functionCallString, functionCallNameString string
|
||||||
|
|
||||||
FunctionsConfig Functions `yaml:"function"`
|
FunctionsConfig Functions `yaml:"function"`
|
||||||
|
|
||||||
|
SystemPrompt string `yaml:"system_prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Functions struct {
|
type Functions struct {
|
||||||
@ -58,10 +60,11 @@ type Functions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TemplateConfig struct {
|
type TemplateConfig struct {
|
||||||
Completion string `yaml:"completion"`
|
|
||||||
Functions string `yaml:"function"`
|
|
||||||
Chat string `yaml:"chat"`
|
Chat string `yaml:"chat"`
|
||||||
|
ChatMessage string `yaml:"chat_message"`
|
||||||
|
Completion string `yaml:"completion"`
|
||||||
Edit string `yaml:"edit"`
|
Edit string `yaml:"edit"`
|
||||||
|
Functions string `yaml:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigLoader struct {
|
type ConfigLoader struct {
|
||||||
|
@ -43,12 +43,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
processFunctions := false
|
processFunctions := false
|
||||||
funcs := grammar.Functions{}
|
funcs := grammar.Functions{}
|
||||||
model, input, err := readInput(c, o.Loader, true)
|
modelFile, input, err := readInput(c, o.Loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -110,9 +110,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
var predInput string
|
var predInput string
|
||||||
|
|
||||||
mess := []string{}
|
mess := []string{}
|
||||||
for _, i := range input.Messages {
|
for messageIndex, i := range input.Messages {
|
||||||
var content string
|
var content string
|
||||||
role := i.Role
|
role := i.Role
|
||||||
|
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||||
@ -124,6 +125,29 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
r := config.Roles[role]
|
r := config.Roles[role]
|
||||||
contentExists := i.Content != nil && *i.Content != ""
|
contentExists := i.Content != nil && *i.Content != ""
|
||||||
|
// First attempt to populate content via a chat message specific template
|
||||||
|
if config.TemplateConfig.ChatMessage != "" {
|
||||||
|
chatMessageData := model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: config.SystemPrompt,
|
||||||
|
Role: r,
|
||||||
|
RoleName: role,
|
||||||
|
Content: *i.Content,
|
||||||
|
MessageIndex: messageIndex,
|
||||||
|
}
|
||||||
|
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||||
|
} else {
|
||||||
|
if templatedChatMessage == "" {
|
||||||
|
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||||
|
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||||
|
content = templatedChatMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If this model doesn't have such a template, or if
|
||||||
|
if content == "" {
|
||||||
if r != "" {
|
if r != "" {
|
||||||
if contentExists {
|
if contentExists {
|
||||||
content = fmt.Sprint(r, " ", *i.Content)
|
content = fmt.Sprint(r, " ", *i.Content)
|
||||||
@ -153,6 +177,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mess = append(mess, content)
|
mess = append(mess, content)
|
||||||
}
|
}
|
||||||
@ -181,10 +206,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input string
|
|
||||||
Functions []grammar.Function
|
|
||||||
}{
|
|
||||||
Input: predInput,
|
Input: predInput,
|
||||||
Functions: funcs,
|
Functions: funcs,
|
||||||
})
|
})
|
||||||
|
@ -38,14 +38,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readInput(c, o.Loader, true)
|
modelFile, input, err := readInput(c, o.Loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input)
|
log.Debug().Msgf("`input`: %+v", input)
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -76,9 +76,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
predInput := config.PromptStrings[0]
|
predInput := config.PromptStrings[0]
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input string
|
|
||||||
}{
|
|
||||||
Input: predInput,
|
Input: predInput,
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -124,9 +122,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
var result []Choice
|
var result []Choice
|
||||||
for k, i := range config.PromptStrings {
|
for k, i := range config.PromptStrings {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input string
|
|
||||||
}{
|
|
||||||
Input: i,
|
Input: i,
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -6,18 +6,19 @@ import (
|
|||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readInput(c, o.Loader, true)
|
modelFile, input, err := readInput(c, o.Loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -33,10 +34,10 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
var result []Choice
|
var result []Choice
|
||||||
for _, i := range config.InputStrings {
|
for _, i := range config.InputStrings {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input string
|
Input: i,
|
||||||
Instruction string
|
Instruction: input.Instruction,
|
||||||
}{Input: i})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
i = templatedInput
|
i = templatedInput
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
|
@ -128,7 +128,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
|
|||||||
// It also loads the model
|
// It also loads the model
|
||||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
||||||
return func(s string) (*grpc.Client, error) {
|
return func(s string) (*grpc.Client, error) {
|
||||||
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
log.Debug().Msgf("Loading GRPC Model %s: %+v", backend, *o)
|
||||||
|
|
||||||
var client *grpc.Client
|
var client *grpc.Client
|
||||||
|
|
||||||
|
@ -4,43 +4,81 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
process "github.com/mudler/go-processmanager"
|
process "github.com/mudler/go-processmanager"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Rather than pass an interface{} to the prompt template:
|
||||||
|
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||||
|
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||||
|
type PromptTemplateData struct {
|
||||||
|
Input string
|
||||||
|
Instruction string
|
||||||
|
Functions []grammar.Function
|
||||||
|
MessageIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Ask mudler about FunctionCall stuff being useful at the message level?
|
||||||
|
type ChatMessageTemplateData struct {
|
||||||
|
SystemPrompt string
|
||||||
|
Role string
|
||||||
|
RoleName string
|
||||||
|
Content string
|
||||||
|
MessageIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||||
|
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||||
|
type TemplateType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatPromptTemplate TemplateType = iota
|
||||||
|
ChatMessageTemplate
|
||||||
|
CompletionPromptTemplate
|
||||||
|
EditPromptTemplate
|
||||||
|
FunctionsPromptTemplate
|
||||||
|
|
||||||
|
// The following TemplateType is **NOT** a valid value and MUST be last. It exists to make the sanity integration tests simpler!
|
||||||
|
IntegrationTestTemplate
|
||||||
|
)
|
||||||
|
|
||||||
|
// new idea: what if we declare a struct of these here, and use a loop to check?
|
||||||
|
|
||||||
|
// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
|
||||||
type ModelLoader struct {
|
type ModelLoader struct {
|
||||||
ModelPath string
|
ModelPath string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// TODO: this needs generics
|
// TODO: this needs generics
|
||||||
models map[string]*grpc.Client
|
models map[string]*grpc.Client
|
||||||
grpcProcesses map[string]*process.Process
|
grpcProcesses map[string]*process.Process
|
||||||
promptsTemplates map[string]*template.Template
|
templates map[TemplateType]map[string]*template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewModelLoader(modelPath string) *ModelLoader {
|
func NewModelLoader(modelPath string) *ModelLoader {
|
||||||
return &ModelLoader{
|
nml := &ModelLoader{
|
||||||
ModelPath: modelPath,
|
ModelPath: modelPath,
|
||||||
models: make(map[string]*grpc.Client),
|
models: make(map[string]*grpc.Client),
|
||||||
promptsTemplates: make(map[string]*template.Template),
|
templates: make(map[TemplateType]map[string]*template.Template),
|
||||||
grpcProcesses: make(map[string]*process.Process),
|
grpcProcesses: make(map[string]*process.Process),
|
||||||
}
|
}
|
||||||
|
nml.initializeTemplateMap()
|
||||||
|
return nml
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) ExistsInModelPath(s string) bool {
|
func (ml *ModelLoader) ExistsInModelPath(s string) bool {
|
||||||
_, err := os.Stat(filepath.Join(ml.ModelPath, s))
|
return existsInPath(ml.ModelPath, s)
|
||||||
return err == nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) ListModels() ([]string, error) {
|
func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||||
files, err := ioutil.ReadDir(ml.ModelPath)
|
files, err := os.ReadDir(ml.ModelPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
}
|
}
|
||||||
@ -58,63 +96,6 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
|
|||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
|
|
||||||
ml.mu.Lock()
|
|
||||||
defer ml.mu.Unlock()
|
|
||||||
|
|
||||||
m, ok := ml.promptsTemplates[modelName]
|
|
||||||
if !ok {
|
|
||||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
|
||||||
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
t, exists := ml.promptsTemplates[modelName]
|
|
||||||
if exists {
|
|
||||||
m = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if m == nil {
|
|
||||||
return "", fmt.Errorf("failed loading any template")
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
if err := m.Execute(&buf, in); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
|
|
||||||
// Check if the template was already loaded
|
|
||||||
if _, ok := ml.promptsTemplates[modelName]; ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the model path exists
|
|
||||||
// skip any error here - we run anyway if a template does not exist
|
|
||||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName)
|
|
||||||
|
|
||||||
if !ml.ExistsInModelPath(modelTemplateFile) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the template
|
|
||||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ml.promptsTemplates[modelName] = tmpl
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
|
func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
|
||||||
ml.mu.Lock()
|
ml.mu.Lock()
|
||||||
defer ml.mu.Unlock()
|
defer ml.mu.Unlock()
|
||||||
@ -134,10 +115,13 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is a prompt template, load it
|
// TODO: Add a helper method to iterate all prompt templates associated with a config if and only if it's YAML?
|
||||||
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
|
// Minor perf loss here until this is fixed, but we initialize on first request
|
||||||
return nil, err
|
|
||||||
}
|
// // If there is a prompt template, load it
|
||||||
|
// if err := ml.loadTemplateIfExists(modelName); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
|
||||||
ml.models[modelName] = model
|
ml.models[modelName] = model
|
||||||
return model, nil
|
return model, nil
|
||||||
@ -148,9 +132,9 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
|
|||||||
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
||||||
|
|
||||||
if !m.HealthCheck(context.Background()) {
|
if !m.HealthCheck(context.Background()) {
|
||||||
log.Debug().Msgf("GRPC Model not responding", s)
|
log.Debug().Msgf("GRPC Model not responding: %s", s)
|
||||||
if !ml.grpcProcesses[s].IsAlive() {
|
if !ml.grpcProcesses[s].IsAlive() {
|
||||||
log.Debug().Msgf("GRPC Process is not responding", s)
|
log.Debug().Msgf("GRPC Process is not responding: %s", s)
|
||||||
// stop and delete the process, this forces to re-load the model and re-create again the service
|
// stop and delete the process, this forces to re-load the model and re-create again the service
|
||||||
ml.grpcProcesses[s].Stop()
|
ml.grpcProcesses[s].Stop()
|
||||||
delete(ml.grpcProcesses, s)
|
delete(ml.grpcProcesses, s)
|
||||||
@ -164,3 +148,81 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||||
|
// TODO: should this check be improved?
|
||||||
|
if templateType == ChatMessageTemplate {
|
||||||
|
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
||||||
|
}
|
||||||
|
return ml.evaluateTemplate(templateType, templateName, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||||
|
return ml.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInPath(path string, s string) bool {
|
||||||
|
_, err := os.Stat(filepath.Join(path, s))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) initializeTemplateMap() {
|
||||||
|
// This also seems somewhat clunky as we reference the Test / End of valid data value slug, but it works?
|
||||||
|
for tt := TemplateType(0); tt < IntegrationTestTemplate; tt++ {
|
||||||
|
ml.templates[tt] = make(map[string]*template.Template)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) evaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
||||||
|
ml.mu.Lock()
|
||||||
|
defer ml.mu.Unlock()
|
||||||
|
|
||||||
|
m, ok := ml.templates[templateType][templateName]
|
||||||
|
if !ok {
|
||||||
|
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||||
|
loadErr := ml.loadTemplateIfExists(templateType, templateName)
|
||||||
|
if loadErr != nil {
|
||||||
|
return "", loadErr
|
||||||
|
}
|
||||||
|
m = ml.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
if err := m.Execute(&buf, in); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||||
|
// Check if the template was already loaded
|
||||||
|
if _, ok := ml.templates[templateType][templateName]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the model path exists
|
||||||
|
// skip any error here - we run anyway if a template does not exist
|
||||||
|
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||||
|
|
||||||
|
if !ml.ExistsInModelPath(modelTemplateFile) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the template
|
||||||
|
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ml.templates[templateType][templateName] = tmpl
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
7
prompt-templates/llama2-chat-message.tmpl
Normal file
7
prompt-templates/llama2-chat-message.tmpl
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{{if eq .RoleName "assistant"}}{{.Content}}{{else}}
|
||||||
|
[INST]
|
||||||
|
{{if .SystemPrompt}}{{.SystemPrompt}}{{else if eq .RoleName "system"}}<<SYS>>{{.Content}}<</SYS>>
|
||||||
|
|
||||||
|
{{else if .Content}}{{.Content}}{{end}}
|
||||||
|
[/INST]
|
||||||
|
{{end}}
|
23
tests/integration/reflect_test.go
Normal file
23
tests/integration/reflect_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package integration_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
|
||||||
|
Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
|
||||||
|
|
||||||
|
ttc := reflect.TypeOf(config.TemplateConfig{})
|
||||||
|
|
||||||
|
It("TemplateConfig and TemplateType should have the same number of valid values", func() {
|
||||||
|
const lastValidTemplateType = model.IntegrationTestTemplate - 1
|
||||||
|
Expect(lastValidTemplateType).To(Equal(ttc.NumField()))
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
Loading…
Reference in New Issue
Block a user