From ccc7cb0287eaba505d033b3511bf9469b4dde4e7 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 22 Oct 2024 09:34:05 +0200 Subject: [PATCH] feat(templates): use a single template for multimodals messages (#3892) Signed-off-by: Ettore Di Giacinto --- core/config/backend_config.go | 4 +- core/http/endpoints/openai/request.go | 43 ++++++++-------- pkg/templates/multimodal.go | 50 +++++++++++++++++-- pkg/templates/multimodal_test.go | 72 ++++++++++++++++++++++++++- 4 files changed, 140 insertions(+), 29 deletions(-) diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 79e134d8..b386d096 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -197,9 +197,7 @@ type TemplateConfig struct { // It defaults to \n JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"` - Video string `yaml:"video"` - Image string `yaml:"image"` - Audio string `yaml:"audio"` + Multimodal string `yaml:"multimodal"` } func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index a418433e..1309fa82 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -149,6 +149,10 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque // Decode each request's message content imgIndex, vidIndex, audioIndex := 0, 0, 0 for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + switch content := m.Content.(type) { case string: input.Messages[i].StringContent = content @@ -156,11 +160,16 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque dat, _ := json.Marshal(content) c := []schema.Content{} json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + CONTENT: for _, pp := range c { switch pp.Type { case "text": - input.Messages[i].StringContent = pp.Text + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text case "video", "video_url": // Decode content as base64 either if it's an URL or base64 text base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) @@ -169,14 +178,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque continue CONTENT } input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - - t := "[vid-{{.ID}}]{{.Text}}" - if config.TemplateConfig.Video != "" { - t = config.TemplateConfig.Video - } - // set a placeholder for each image - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent) vidIndex++ + nrOfVideosInMessage++ case "audio_url", "audio": // Decode content as base64 either if it's an URL or base64 text base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) @@ -185,13 +188,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque continue CONTENT } input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - t := "[audio-{{.ID}}]{{.Text}}" - if config.TemplateConfig.Audio != "" { - t = config.TemplateConfig.Audio - } - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent) audioIndex++ + nrOfAudiosInMessage++ case "image_url", "image": // Decode content as base64 either if it's an URL or base64 text base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) @@ -200,16 +198,21 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque continue CONTENT } - t := "[img-{{.ID}}]{{.Text}}" - if config.TemplateConfig.Image != "" { - t = config.TemplateConfig.Image - } input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent) + imgIndex++ + nrOfImgsInMessage++ } } + + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) } } diff --git a/pkg/templates/multimodal.go b/pkg/templates/multimodal.go index a2056640..3a19b07a 100644 --- a/pkg/templates/multimodal.go +++ b/pkg/templates/multimodal.go @@ -7,20 +7,60 @@ import ( "github.com/Masterminds/sprig/v3" ) -func TemplateMultiModal(templateString string, templateID int, text string) (string, error) { +type MultiModalOptions struct { + TotalImages int + TotalAudios int + TotalVideos int + + ImagesInMessage int + AudiosInMessage int + VideosInMessage int +} + +type MultimodalContent struct { + ID int +} + +const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}" + +func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) { + if templateString == "" { + templateString = DefaultMultiModalTemplate + } + // compile the template tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString) if err != nil { return "", err } + + videos := []MultimodalContent{} + for i := 0; i < opts.VideosInMessage; i++ { + videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)}) + } + + audios := []MultimodalContent{} + for i := 0; i < opts.AudiosInMessage; i++ { + audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)}) + } + + images := []MultimodalContent{} + for i := 0; i < opts.ImagesInMessage; i++ { + images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)}) + } + result := bytes.NewBuffer(nil) // execute the template err = tmpl.Execute(result, struct { - ID int - Text string + Audio []MultimodalContent + Images []MultimodalContent + Video []MultimodalContent + Text string }{ - ID: templateID, - Text: text, + Audio: audios, + Images: images, + Video: videos, + Text: text, }) return result.String(), err } diff --git a/pkg/templates/multimodal_test.go b/pkg/templates/multimodal_test.go index d1a8bd5b..ef8607a7 100644 --- a/pkg/templates/multimodal_test.go +++ b/pkg/templates/multimodal_test.go @@ -11,7 +11,77 @@ import ( var _ = Describe("EvaluateTemplate", func() { Context("templating simple strings for multimodal chat", func() { It("should template messages correctly", func() { - result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar") + result, err := TemplateMultiModal("", MultiModalOptions{ + TotalImages: 1, + TotalAudios: 0, + TotalVideos: 0, + ImagesInMessage: 1, + AudiosInMessage: 0, + VideosInMessage: 0, + }, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("[img-0]bar")) + }) + + It("should handle messages with more images correctly", func() { + result, err := TemplateMultiModal("", MultiModalOptions{ + TotalImages: 2, + TotalAudios: 0, + TotalVideos: 0, + ImagesInMessage: 2, + AudiosInMessage: 0, + VideosInMessage: 0, + }, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("[img-0][img-1]bar")) + }) + It("should handle messages with more images correctly", func() { + result, err := TemplateMultiModal("", MultiModalOptions{ + TotalImages: 4, + TotalAudios: 1, + TotalVideos: 0, + ImagesInMessage: 2, + AudiosInMessage: 1, + VideosInMessage: 0, + }, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("[audio-0][img-2][img-3]bar")) + }) + It("should handle messages with more images correctly", func() { + result, err := TemplateMultiModal("", MultiModalOptions{ + TotalImages: 3, + TotalAudios: 1, + TotalVideos: 0, + ImagesInMessage: 1, + AudiosInMessage: 1, + VideosInMessage: 0, + }, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("[audio-0][img-2]bar")) + }) + It("should handle messages with more images correctly", func() { + result, err := TemplateMultiModal("", MultiModalOptions{ + TotalImages: 0, + TotalAudios: 0, + TotalVideos: 0, + ImagesInMessage: 0, + AudiosInMessage: 0, + VideosInMessage: 0, + }, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("bar")) + }) + }) + Context("templating with custom defaults", func() { + It("should handle messages with more images correctly", func() { + result, err := TemplateMultiModal("{{ range .Audio }}[audio-{{ add1 .ID}}]{{end}}{{ range .Images }}[img-{{ add1 .ID}}]{{end}}{{ range .Video }}[vid-{{ add1 .ID}}]{{end}}{{.Text}}", MultiModalOptions{ + TotalImages: 1, + TotalAudios: 0, + TotalVideos: 0, + ImagesInMessage: 1, + AudiosInMessage: 0, + VideosInMessage: 0, + }, "bar") Expect(err).NotTo(HaveOccurred()) Expect(result).To(Equal("[img-1]bar")) })