mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-05 18:18:20 +00:00
feat(templates): use a single template for multimodal messages (#3892)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
a1d6cc93a8
commit
ccc7cb0287
@ -197,9 +197,7 @@ type TemplateConfig struct {
|
|||||||
// It defaults to \n
|
// It defaults to \n
|
||||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||||
|
|
||||||
Video string `yaml:"video"`
|
Multimodal string `yaml:"multimodal"`
|
||||||
Image string `yaml:"image"`
|
|
||||||
Audio string `yaml:"audio"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
@ -149,6 +149,10 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
// Decode each request's message content
|
// Decode each request's message content
|
||||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||||
for i, m := range input.Messages {
|
for i, m := range input.Messages {
|
||||||
|
nrOfImgsInMessage := 0
|
||||||
|
nrOfVideosInMessage := 0
|
||||||
|
nrOfAudiosInMessage := 0
|
||||||
|
|
||||||
switch content := m.Content.(type) {
|
switch content := m.Content.(type) {
|
||||||
case string:
|
case string:
|
||||||
input.Messages[i].StringContent = content
|
input.Messages[i].StringContent = content
|
||||||
@ -156,11 +160,16 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
dat, _ := json.Marshal(content)
|
dat, _ := json.Marshal(content)
|
||||||
c := []schema.Content{}
|
c := []schema.Content{}
|
||||||
json.Unmarshal(dat, &c)
|
json.Unmarshal(dat, &c)
|
||||||
|
|
||||||
|
textContent := ""
|
||||||
|
// we will template this at the end
|
||||||
|
|
||||||
CONTENT:
|
CONTENT:
|
||||||
for _, pp := range c {
|
for _, pp := range c {
|
||||||
switch pp.Type {
|
switch pp.Type {
|
||||||
case "text":
|
case "text":
|
||||||
input.Messages[i].StringContent = pp.Text
|
textContent += pp.Text
|
||||||
|
//input.Messages[i].StringContent = pp.Text
|
||||||
case "video", "video_url":
|
case "video", "video_url":
|
||||||
// Decode content as base64 either if it's an URL or base64 text
|
// Decode content as base64 either if it's an URL or base64 text
|
||||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||||
@ -169,14 +178,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
continue CONTENT
|
continue CONTENT
|
||||||
}
|
}
|
||||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||||
|
|
||||||
t := "[vid-{{.ID}}]{{.Text}}"
|
|
||||||
if config.TemplateConfig.Video != "" {
|
|
||||||
t = config.TemplateConfig.Video
|
|
||||||
}
|
|
||||||
// set a placeholder for each image
|
|
||||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
|
|
||||||
vidIndex++
|
vidIndex++
|
||||||
|
nrOfVideosInMessage++
|
||||||
case "audio_url", "audio":
|
case "audio_url", "audio":
|
||||||
// Decode content as base64 either if it's an URL or base64 text
|
// Decode content as base64 either if it's an URL or base64 text
|
||||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||||
@ -185,13 +188,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
continue CONTENT
|
continue CONTENT
|
||||||
}
|
}
|
||||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||||
// set a placeholder for each image
|
|
||||||
t := "[audio-{{.ID}}]{{.Text}}"
|
|
||||||
if config.TemplateConfig.Audio != "" {
|
|
||||||
t = config.TemplateConfig.Audio
|
|
||||||
}
|
|
||||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
|
|
||||||
audioIndex++
|
audioIndex++
|
||||||
|
nrOfAudiosInMessage++
|
||||||
case "image_url", "image":
|
case "image_url", "image":
|
||||||
// Decode content as base64 either if it's an URL or base64 text
|
// Decode content as base64 either if it's an URL or base64 text
|
||||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||||
@ -200,16 +198,21 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
continue CONTENT
|
continue CONTENT
|
||||||
}
|
}
|
||||||
|
|
||||||
t := "[img-{{.ID}}]{{.Text}}"
|
|
||||||
if config.TemplateConfig.Image != "" {
|
|
||||||
t = config.TemplateConfig.Image
|
|
||||||
}
|
|
||||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||||
// set a placeholder for each image
|
|
||||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
|
|
||||||
imgIndex++
|
imgIndex++
|
||||||
|
nrOfImgsInMessage++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||||
|
TotalImages: imgIndex,
|
||||||
|
TotalVideos: vidIndex,
|
||||||
|
TotalAudios: audioIndex,
|
||||||
|
ImagesInMessage: nrOfImgsInMessage,
|
||||||
|
VideosInMessage: nrOfVideosInMessage,
|
||||||
|
AudiosInMessage: nrOfAudiosInMessage,
|
||||||
|
}, textContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,19 +7,59 @@ import (
|
|||||||
"github.com/Masterminds/sprig/v3"
|
"github.com/Masterminds/sprig/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
|
type MultiModalOptions struct {
|
||||||
|
TotalImages int
|
||||||
|
TotalAudios int
|
||||||
|
TotalVideos int
|
||||||
|
|
||||||
|
ImagesInMessage int
|
||||||
|
AudiosInMessage int
|
||||||
|
VideosInMessage int
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultimodalContent struct {
|
||||||
|
ID int
|
||||||
|
}
|
||||||
|
|
||||||
|
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
|
||||||
|
|
||||||
|
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
|
||||||
|
if templateString == "" {
|
||||||
|
templateString = DefaultMultiModalTemplate
|
||||||
|
}
|
||||||
|
|
||||||
// compile the template
|
// compile the template
|
||||||
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
|
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
videos := []MultimodalContent{}
|
||||||
|
for i := 0; i < opts.VideosInMessage; i++ {
|
||||||
|
videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)})
|
||||||
|
}
|
||||||
|
|
||||||
|
audios := []MultimodalContent{}
|
||||||
|
for i := 0; i < opts.AudiosInMessage; i++ {
|
||||||
|
audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)})
|
||||||
|
}
|
||||||
|
|
||||||
|
images := []MultimodalContent{}
|
||||||
|
for i := 0; i < opts.ImagesInMessage; i++ {
|
||||||
|
images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)})
|
||||||
|
}
|
||||||
|
|
||||||
result := bytes.NewBuffer(nil)
|
result := bytes.NewBuffer(nil)
|
||||||
// execute the template
|
// execute the template
|
||||||
err = tmpl.Execute(result, struct {
|
err = tmpl.Execute(result, struct {
|
||||||
ID int
|
Audio []MultimodalContent
|
||||||
|
Images []MultimodalContent
|
||||||
|
Video []MultimodalContent
|
||||||
Text string
|
Text string
|
||||||
}{
|
}{
|
||||||
ID: templateID,
|
Audio: audios,
|
||||||
|
Images: images,
|
||||||
|
Video: videos,
|
||||||
Text: text,
|
Text: text,
|
||||||
})
|
})
|
||||||
return result.String(), err
|
return result.String(), err
|
||||||
|
@ -11,7 +11,77 @@ import (
|
|||||||
var _ = Describe("EvaluateTemplate", func() {
|
var _ = Describe("EvaluateTemplate", func() {
|
||||||
Context("templating simple strings for multimodal chat", func() {
|
Context("templating simple strings for multimodal chat", func() {
|
||||||
It("should template messages correctly", func() {
|
It("should template messages correctly", func() {
|
||||||
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
|
result, err := TemplateMultiModal("", MultiModalOptions{
|
||||||
|
TotalImages: 1,
|
||||||
|
TotalAudios: 0,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 1,
|
||||||
|
AudiosInMessage: 0,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("[img-0]bar"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle messages with more images correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("", MultiModalOptions{
|
||||||
|
TotalImages: 2,
|
||||||
|
TotalAudios: 0,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 2,
|
||||||
|
AudiosInMessage: 0,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("[img-0][img-1]bar"))
|
||||||
|
})
|
||||||
|
It("should handle messages with more images correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("", MultiModalOptions{
|
||||||
|
TotalImages: 4,
|
||||||
|
TotalAudios: 1,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 2,
|
||||||
|
AudiosInMessage: 1,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("[audio-0][img-2][img-3]bar"))
|
||||||
|
})
|
||||||
|
It("should handle messages with more images correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("", MultiModalOptions{
|
||||||
|
TotalImages: 3,
|
||||||
|
TotalAudios: 1,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 1,
|
||||||
|
AudiosInMessage: 1,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("[audio-0][img-2]bar"))
|
||||||
|
})
|
||||||
|
It("should handle messages with more images correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("", MultiModalOptions{
|
||||||
|
TotalImages: 0,
|
||||||
|
TotalAudios: 0,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 0,
|
||||||
|
AudiosInMessage: 0,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("bar"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
Context("templating with custom defaults", func() {
|
||||||
|
It("should handle messages with more images correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("{{ range .Audio }}[audio-{{ add1 .ID}}]{{end}}{{ range .Images }}[img-{{ add1 .ID}}]{{end}}{{ range .Video }}[vid-{{ add1 .ID}}]{{end}}{{.Text}}", MultiModalOptions{
|
||||||
|
TotalImages: 1,
|
||||||
|
TotalAudios: 0,
|
||||||
|
TotalVideos: 0,
|
||||||
|
ImagesInMessage: 1,
|
||||||
|
AudiosInMessage: 0,
|
||||||
|
VideosInMessage: 0,
|
||||||
|
}, "bar")
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(result).To(Equal("[img-1]bar"))
|
Expect(result).To(Equal("[img-1]bar"))
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user