diff --git a/backend/backend.proto b/backend/backend.proto
index 6ef83567..31bd63e5 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -135,6 +135,7 @@ message PredictOptions {
   bool UseTokenizerTemplate = 43;
   repeated Message Messages = 44;
   repeated string Videos = 45;
+  repeated string Audios = 46;
 }
 
 // The response message containing the result
diff --git a/core/backend/llm.go b/core/backend/llm.go
index fa4c0709..f74071ba 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -31,7 +31,7 @@ type TokenUsage struct {
 	Completion int
 }
 
-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model
 	threads := c.Threads
 	if *threads == 0 && o.Threads != 0 {
@@ -102,6 +102,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 	opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
 	opts.Images = images
 	opts.Videos = videos
+	opts.Audios = audios
 
 	tokenUsage := TokenUsage{}
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 742a4add..b937120a 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -644,8 +644,12 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 	for _, m := range input.Messages {
 		videos = append(videos, m.StringVideos...)
 	}
+	audios := []string{}
+	for _, m := range input.Messages {
+		audios = append(audios, m.StringAudios...)
+	}
 
-	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, ml, *config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
 	if err != nil {
 		log.Error().Err(err).Msg("model inference failed")
 		return "", err
diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go
index 4008ba3d..da75d3a1 100644
--- a/core/http/endpoints/openai/inference.go
+++ b/core/http/endpoints/openai/inference.go
@@ -31,9 +31,13 @@ func ComputeChoices(
 	for _, m := range req.Messages {
 		videos = append(videos, m.StringVideos...)
 	}
+	audios := []string{}
+	for _, m := range req.Messages {
+		audios = append(audios, m.StringAudios...)
+	}
 
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, loader, *config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
 	if err != nil {
 		return result, backend.TokenUsage{}, err
 	}
diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go
index 456a1e0c..e24dd28f 100644
--- a/core/http/endpoints/openai/request.go
+++ b/core/http/endpoints/openai/request.go
@@ -135,7 +135,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 	}
 
 	// Decode each request's message content
-	imgIndex, vidIndex := 0, 0
+	imgIndex, vidIndex, audioIndex := 0, 0, 0
 	for i, m := range input.Messages {
 		switch content := m.Content.(type) {
 		case string:
@@ -160,9 +160,20 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 					// set a placeholder for each image
 					input.Messages[i].StringContent = fmt.Sprintf("[vid-%d]", vidIndex) + input.Messages[i].StringContent
 					vidIndex++
+				case "audio_url", "audio":
+					// Decode content as base64 either if it's an URL or base64 text
+					base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
+					if err != nil {
+						log.Error().Msgf("Failed encoding audio: %s", err)
+						continue CONTENT
+					}
+					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
+					// set a placeholder for each audio
+					input.Messages[i].StringContent = fmt.Sprintf("[audio-%d]", audioIndex) + input.Messages[i].StringContent
+					audioIndex++
 				case "image_url", "image":
 					// Decode content as base64 either if it's an URL or base64 text
 					base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
 					if err != nil {
 						log.Error().Msgf("Failed encoding image: %s", err)
 						continue CONTENT
diff --git a/core/schema/openai.go b/core/schema/openai.go
index 32ed716b..15bcd13d 100644
--- a/core/schema/openai.go
+++ b/core/schema/openai.go
@@ -58,6 +58,7 @@ type Content struct {
 	Type     string     `json:"type" yaml:"type"`
 	Text     string     `json:"text" yaml:"text"`
 	ImageURL ContentURL `json:"image_url" yaml:"image_url"`
+	AudioURL ContentURL `json:"audio_url" yaml:"audio_url"`
 	VideoURL ContentURL `json:"video_url" yaml:"video_url"`
 }
@@ -78,6 +79,7 @@ type Message struct {
 	StringContent string   `json:"string_content,omitempty" yaml:"string_content,omitempty"`
 	StringImages  []string `json:"string_images,omitempty" yaml:"string_images,omitempty"`
 	StringVideos  []string `json:"string_videos,omitempty" yaml:"string_videos,omitempty"`
+	StringAudios  []string `json:"string_audios,omitempty" yaml:"string_audios,omitempty"`
 
 	// A result of a function call
 	FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`