mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-29 08:58:52 +00:00
296 lines
8.9 KiB
Go
296 lines
8.9 KiB
Go
|
package templates
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/mudler/LocalAI/core/config"
|
||
|
"github.com/mudler/LocalAI/core/schema"
|
||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||
|
"github.com/rs/zerolog/log"
|
||
|
)
|
||
|
|
||
|
// Rather than pass an interface{} to the prompt template:
|
||
|
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||
|
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||
|
type PromptTemplateData struct {
|
||
|
SystemPrompt string
|
||
|
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||
|
Input string
|
||
|
Instruction string
|
||
|
Functions []functions.Function
|
||
|
MessageIndex int
|
||
|
}
|
||
|
|
||
|
// ChatMessageTemplateData is the data made available to a chat-message
// template when a single message of a conversation is rendered.
type ChatMessageTemplateData struct {
	// SystemPrompt is the configured system prompt.
	SystemPrompt string

	// Role is the configured role prefix for the message's role (may be empty).
	Role string

	// RoleName is the role name itself, e.g. "user", "assistant", or
	// "assistant_function_call" when that remapping applies.
	RoleName string

	// FunctionName is the message's name field.
	FunctionName string

	// Content is the message text.
	Content string

	// MessageIndex is the zero-based position of the message in the conversation.
	MessageIndex int

	// Function is true for the last message when a grammar is configured.
	Function bool

	// FunctionCall holds the message's tool calls when it has any,
	// otherwise its function call (possibly nil).
	FunctionCall any

	// LastMessage is true only for the final message of the conversation.
	LastMessage bool
}
|
||
|
|
||
|
const (
|
||
|
ChatPromptTemplate TemplateType = iota
|
||
|
ChatMessageTemplate
|
||
|
CompletionPromptTemplate
|
||
|
EditPromptTemplate
|
||
|
FunctionsPromptTemplate
|
||
|
)
|
||
|
|
||
|
type Evaluator struct {
|
||
|
cache *templateCache
|
||
|
}
|
||
|
|
||
|
func NewEvaluator(modelPath string) *Evaluator {
|
||
|
return &Evaluator{
|
||
|
cache: newTemplateCache(modelPath),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
|
||
|
template := ""
|
||
|
|
||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||
|
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||
|
template = config.Model
|
||
|
}
|
||
|
|
||
|
switch templateType {
|
||
|
case CompletionPromptTemplate:
|
||
|
if config.TemplateConfig.Completion != "" {
|
||
|
template = config.TemplateConfig.Completion
|
||
|
}
|
||
|
case EditPromptTemplate:
|
||
|
if config.TemplateConfig.Edit != "" {
|
||
|
template = config.TemplateConfig.Edit
|
||
|
}
|
||
|
case ChatPromptTemplate:
|
||
|
if config.TemplateConfig.Chat != "" {
|
||
|
template = config.TemplateConfig.Chat
|
||
|
}
|
||
|
case FunctionsPromptTemplate:
|
||
|
if config.TemplateConfig.Functions != "" {
|
||
|
template = config.TemplateConfig.Functions
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if template == "" {
|
||
|
return in.Input, nil
|
||
|
}
|
||
|
|
||
|
if config.TemplateConfig.JinjaTemplate {
|
||
|
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
|
||
|
}
|
||
|
|
||
|
return e.cache.evaluateTemplate(templateType, template, in)
|
||
|
}
|
||
|
|
||
|
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||
|
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||
|
}
|
||
|
|
||
|
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
|
||
|
|
||
|
conversation := make(map[string]interface{})
|
||
|
messages := make([]map[string]interface{}, len(messageData))
|
||
|
|
||
|
// convert from ChatMessageTemplateData to what the jinja template expects
|
||
|
|
||
|
for _, message := range messageData {
|
||
|
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
|
||
|
var data []byte
|
||
|
data, _ = json.Marshal(message.FunctionCall)
|
||
|
messages = append(messages, map[string]interface{}{
|
||
|
"role": message.RoleName,
|
||
|
"content": message.Content,
|
||
|
"tool_call": string(data),
|
||
|
})
|
||
|
}
|
||
|
|
||
|
conversation["messages"] = messages
|
||
|
|
||
|
// if tools are detected, add these
|
||
|
if len(funcs) > 0 {
|
||
|
conversation["tools"] = funcs
|
||
|
}
|
||
|
|
||
|
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||
|
}
|
||
|
|
||
|
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||
|
|
||
|
conversation := make(map[string]interface{})
|
||
|
|
||
|
conversation["system_prompt"] = in.SystemPrompt
|
||
|
conversation["content"] = in.Input
|
||
|
|
||
|
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||
|
}
|
||
|
|
||
|
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||
|
|
||
|
if config.TemplateConfig.JinjaTemplate {
|
||
|
var messageData []ChatMessageTemplateData
|
||
|
for messageIndex, i := range messages {
|
||
|
fcall := i.FunctionCall
|
||
|
if len(i.ToolCalls) > 0 {
|
||
|
fcall = i.ToolCalls
|
||
|
}
|
||
|
messageData = append(messageData, ChatMessageTemplateData{
|
||
|
SystemPrompt: config.SystemPrompt,
|
||
|
Role: config.Roles[i.Role],
|
||
|
RoleName: i.Role,
|
||
|
Content: i.StringContent,
|
||
|
FunctionCall: fcall,
|
||
|
FunctionName: i.Name,
|
||
|
LastMessage: messageIndex == (len(messages) - 1),
|
||
|
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||
|
MessageIndex: messageIndex,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
|
||
|
if err == nil {
|
||
|
return templatedInput
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var predInput string
|
||
|
suppressConfigSystemPrompt := false
|
||
|
mess := []string{}
|
||
|
for messageIndex, i := range messages {
|
||
|
var content string
|
||
|
role := i.Role
|
||
|
|
||
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||
|
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
||
|
roleFn := "assistant_function_call"
|
||
|
r := config.Roles[roleFn]
|
||
|
if r != "" {
|
||
|
role = roleFn
|
||
|
}
|
||
|
}
|
||
|
r := config.Roles[role]
|
||
|
contentExists := i.Content != nil && i.StringContent != ""
|
||
|
|
||
|
fcall := i.FunctionCall
|
||
|
if len(i.ToolCalls) > 0 {
|
||
|
fcall = i.ToolCalls
|
||
|
}
|
||
|
|
||
|
// First attempt to populate content via a chat message specific template
|
||
|
if config.TemplateConfig.ChatMessage != "" {
|
||
|
chatMessageData := ChatMessageTemplateData{
|
||
|
SystemPrompt: config.SystemPrompt,
|
||
|
Role: r,
|
||
|
RoleName: role,
|
||
|
Content: i.StringContent,
|
||
|
FunctionCall: fcall,
|
||
|
FunctionName: i.Name,
|
||
|
LastMessage: messageIndex == (len(messages) - 1),
|
||
|
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||
|
MessageIndex: messageIndex,
|
||
|
}
|
||
|
templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||
|
if err != nil {
|
||
|
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
|
||
|
} else {
|
||
|
if templatedChatMessage == "" {
|
||
|
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||
|
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||
|
}
|
||
|
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||
|
content = templatedChatMessage
|
||
|
}
|
||
|
}
|
||
|
|
||
|
marshalAnyRole := func(f any) {
|
||
|
j, err := json.Marshal(f)
|
||
|
if err == nil {
|
||
|
if contentExists {
|
||
|
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||
|
} else {
|
||
|
content = fmt.Sprint(r, " ", string(j))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
marshalAny := func(f any) {
|
||
|
j, err := json.Marshal(f)
|
||
|
if err == nil {
|
||
|
if contentExists {
|
||
|
content += "\n" + string(j)
|
||
|
} else {
|
||
|
content = string(j)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||
|
if content == "" {
|
||
|
if r != "" {
|
||
|
if contentExists {
|
||
|
content = fmt.Sprint(r, i.StringContent)
|
||
|
}
|
||
|
|
||
|
if i.FunctionCall != nil {
|
||
|
marshalAnyRole(i.FunctionCall)
|
||
|
}
|
||
|
if i.ToolCalls != nil {
|
||
|
marshalAnyRole(i.ToolCalls)
|
||
|
}
|
||
|
} else {
|
||
|
if contentExists {
|
||
|
content = fmt.Sprint(i.StringContent)
|
||
|
}
|
||
|
if i.FunctionCall != nil {
|
||
|
marshalAny(i.FunctionCall)
|
||
|
}
|
||
|
if i.ToolCalls != nil {
|
||
|
marshalAny(i.ToolCalls)
|
||
|
}
|
||
|
}
|
||
|
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||
|
if contentExists && role == "system" {
|
||
|
suppressConfigSystemPrompt = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
mess = append(mess, content)
|
||
|
}
|
||
|
|
||
|
joinCharacter := "\n"
|
||
|
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
|
||
|
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
|
||
|
}
|
||
|
|
||
|
predInput = strings.Join(mess, joinCharacter)
|
||
|
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||
|
|
||
|
promptTemplate := ChatPromptTemplate
|
||
|
|
||
|
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||
|
promptTemplate = FunctionsPromptTemplate
|
||
|
}
|
||
|
|
||
|
templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
|
||
|
SystemPrompt: config.SystemPrompt,
|
||
|
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||
|
Input: predInput,
|
||
|
Functions: funcs,
|
||
|
})
|
||
|
if err == nil {
|
||
|
predInput = templatedInput
|
||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||
|
} else {
|
||
|
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||
|
}
|
||
|
|
||
|
return predInput
|
||
|
}
|