mirror of
https://github.com/mudler/LocalAI.git
synced 2025-03-10 22:43:59 +00:00
feat(multimodal): allow to template placeholders (#3728)
feat(multimodal): allow to template image placeholders Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
04c0841ca9
commit
648ffdf449
@ -196,6 +196,10 @@ type TemplateConfig struct {
|
|||||||
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
|
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
|
||||||
// It defaults to \n
|
// It defaults to \n
|
||||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||||
|
|
||||||
|
Video string `yaml:"video"`
|
||||||
|
Image string `yaml:"image"`
|
||||||
|
Audio string `yaml:"audio"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@ -168,8 +169,13 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
continue CONTENT
|
continue CONTENT
|
||||||
}
|
}
|
||||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||||
|
|
||||||
|
t := "[vid-{{.ID}}]{{.Text}}"
|
||||||
|
if config.TemplateConfig.Video != "" {
|
||||||
|
t = config.TemplateConfig.Video
|
||||||
|
}
|
||||||
// set a placeholder for each image
|
// set a placeholder for each image
|
||||||
input.Messages[i].StringContent = fmt.Sprintf("[vid-%d]", vidIndex) + input.Messages[i].StringContent
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
|
||||||
vidIndex++
|
vidIndex++
|
||||||
case "audio_url", "audio":
|
case "audio_url", "audio":
|
||||||
// Decode content as base64 either if it's an URL or base64 text
|
// Decode content as base64 either if it's an URL or base64 text
|
||||||
@ -180,7 +186,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
}
|
}
|
||||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||||
// set a placeholder for each image
|
// set a placeholder for each image
|
||||||
input.Messages[i].StringContent = fmt.Sprintf("[audio-%d]", audioIndex) + input.Messages[i].StringContent
|
t := "[audio-{{.ID}}]{{.Text}}"
|
||||||
|
if config.TemplateConfig.Audio != "" {
|
||||||
|
t = config.TemplateConfig.Audio
|
||||||
|
}
|
||||||
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
|
||||||
audioIndex++
|
audioIndex++
|
||||||
case "image_url", "image":
|
case "image_url", "image":
|
||||||
// Decode content as base64 either if it's an URL or base64 text
|
// Decode content as base64 either if it's an URL or base64 text
|
||||||
@ -189,9 +199,14 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
log.Error().Msgf("Failed encoding image: %s", err)
|
log.Error().Msgf("Failed encoding image: %s", err)
|
||||||
continue CONTENT
|
continue CONTENT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t := "[img-{{.ID}}]{{.Text}}"
|
||||||
|
if config.TemplateConfig.Image != "" {
|
||||||
|
t = config.TemplateConfig.Image
|
||||||
|
}
|
||||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||||
// set a placeholder for each image
|
// set a placeholder for each image
|
||||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", imgIndex) + input.Messages[i].StringContent
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
|
||||||
imgIndex++
|
imgIndex++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -314,7 +314,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||||||
|
|
||||||
client = NewModel(modelID, serverAddress, process)
|
client = NewModel(modelID, serverAddress, process)
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Msg("external backend is uri")
|
log.Debug().Msg("external backend is a uri")
|
||||||
// address
|
// address
|
||||||
client = NewModel(modelID, uri, nil)
|
client = NewModel(modelID, uri, nil)
|
||||||
}
|
}
|
||||||
|
24
pkg/templates/multimodal.go
Normal file
24
pkg/templates/multimodal.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package templates
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
|
||||||
|
// compile the template
|
||||||
|
tmpl, err := template.New("template").Parse(templateString)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
result := bytes.NewBuffer(nil)
|
||||||
|
// execute the template
|
||||||
|
err = tmpl.Execute(result, struct {
|
||||||
|
ID int
|
||||||
|
Text string
|
||||||
|
}{
|
||||||
|
ID: templateID,
|
||||||
|
Text: text,
|
||||||
|
})
|
||||||
|
return result.String(), err
|
||||||
|
}
|
19
pkg/templates/multimodal_test.go
Normal file
19
pkg/templates/multimodal_test.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package templates_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/mudler/LocalAI/pkg/templates" // Update with your module path
|
||||||
|
|
||||||
|
// Update with your module path
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("EvaluateTemplate", func() {
|
||||||
|
Context("templating simple strings for multimodal chat", func() {
|
||||||
|
It("should template messages correctly", func() {
|
||||||
|
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(result).To(Equal("[img-1]bar"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
Loading…
x
Reference in New Issue
Block a user