diff --git a/core/backend/llm.go b/core/backend/llm.go
index 1cad6db5..14eb8569 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -116,6 +116,11 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 		}
 
 		if tokenCallback != nil {
+
+			if c.TemplateConfig.ReplyPrefix != "" {
+				tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
+			}
+
 			ss := ""
 
 			var partialRune []byte
@@ -165,8 +170,13 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 			tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
 			tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
 
+			response := string(reply.Message)
+			if c.TemplateConfig.ReplyPrefix != "" {
+				response = c.TemplateConfig.ReplyPrefix + response
+			}
+
 			return LLMResponse{
-				Response: string(reply.Message),
+				Response: response,
 				Usage:    tokenUsage,
 			}, err
 		}
diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index 1da58c9d..56ffa38c 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -130,28 +130,28 @@ type LLMConfig struct {
 	TrimSpace  []string `yaml:"trimspace"`
 	TrimSuffix []string `yaml:"trimsuffix"`
 
-	ContextSize          *int      `yaml:"context_size"`
-	NUMA                 bool      `yaml:"numa"`
-	LoraAdapter          string    `yaml:"lora_adapter"`
-	LoraBase             string    `yaml:"lora_base"`
-	LoraAdapters         []string  `yaml:"lora_adapters"`
-	LoraScales           []float32 `yaml:"lora_scales"`
-	LoraScale            float32   `yaml:"lora_scale"`
-	NoMulMatQ            bool      `yaml:"no_mulmatq"`
-	DraftModel           string    `yaml:"draft_model"`
-	NDraft               int32     `yaml:"n_draft"`
-	Quantization         string    `yaml:"quantization"`
-	LoadFormat           string    `yaml:"load_format"`
-	GPUMemoryUtilization float32   `yaml:"gpu_memory_utilization"` // vLLM
-	TrustRemoteCode      bool      `yaml:"trust_remote_code"`      // vLLM
-	EnforceEager         bool      `yaml:"enforce_eager"`          // vLLM
-	SwapSpace            int       `yaml:"swap_space"`             // vLLM
-	MaxModelLen          int       `yaml:"max_model_len"`          // vLLM
-	TensorParallelSize   int       `yaml:"tensor_parallel_size"`   // vLLM
-	DisableLogStatus     bool      `yaml:"disable_log_stats"`      // vLLM
-	DType                string    `yaml:"dtype"`                  // vLLM
-	LimitMMPerPrompt     LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
-	MMProj               string    `yaml:"mmproj"`
+	ContextSize          *int             `yaml:"context_size"`
+	NUMA                 bool             `yaml:"numa"`
+	LoraAdapter          string           `yaml:"lora_adapter"`
+	LoraBase             string           `yaml:"lora_base"`
+	LoraAdapters         []string         `yaml:"lora_adapters"`
+	LoraScales           []float32        `yaml:"lora_scales"`
+	LoraScale            float32          `yaml:"lora_scale"`
+	NoMulMatQ            bool             `yaml:"no_mulmatq"`
+	DraftModel           string           `yaml:"draft_model"`
+	NDraft               int32            `yaml:"n_draft"`
+	Quantization         string           `yaml:"quantization"`
+	LoadFormat           string           `yaml:"load_format"`
+	GPUMemoryUtilization float32          `yaml:"gpu_memory_utilization"` // vLLM
+	TrustRemoteCode      bool             `yaml:"trust_remote_code"`      // vLLM
+	EnforceEager         bool             `yaml:"enforce_eager"`          // vLLM
+	SwapSpace            int              `yaml:"swap_space"`             // vLLM
+	MaxModelLen          int              `yaml:"max_model_len"`          // vLLM
+	TensorParallelSize   int              `yaml:"tensor_parallel_size"`   // vLLM
+	DisableLogStatus     bool             `yaml:"disable_log_stats"`      // vLLM
+	DType                string           `yaml:"dtype"`                  // vLLM
+	LimitMMPerPrompt     LimitMMPerPrompt `yaml:"limit_mm_per_prompt"`    // vLLM
+	MMProj               string           `yaml:"mmproj"`
 
 	FlashAttention bool `yaml:"flash_attention"`
 	NoKVOffloading bool `yaml:"no_kv_offloading"`
@@ -171,9 +171,9 @@ type LLMConfig struct {
 
 // LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
 type LimitMMPerPrompt struct {
-	LimitImagePerPrompt int  `yaml:"image"`
-	LimitVideoPerPrompt int  `yaml:"video"`
-	LimitAudioPerPrompt int  `yaml:"audio"`
+	LimitImagePerPrompt int `yaml:"image"`
+	LimitVideoPerPrompt int `yaml:"video"`
+	LimitAudioPerPrompt int `yaml:"audio"`
 }
 
 // AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
@@ -213,6 +213,8 @@ type TemplateConfig struct {
 	Multimodal string `yaml:"multimodal"`
 
 	JinjaTemplate bool `yaml:"jinja_template"`
+
+	ReplyPrefix string `yaml:"reply_prefix"`
 }
 
 func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
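Note (not part of the diff): since `ReplyPrefix` lands in `TemplateConfig` under the `reply_prefix` YAML tag, a model config would set it alongside the other template options. A minimal sketch, assuming `TemplateConfig` is exposed as the `template` key of the model YAML; the model name, chat template, and prefix value are all illustrative:

```yaml
# Hypothetical model config sketch. Only reply_prefix comes from this diff;
# every other field and value here is illustrative.
name: my-model
template:
  chat: my-chat-template
  # Prepended to every reply: the streaming branch emits it through
  # tokenCallback before any model tokens, and the non-streaming branch
  # concatenates it in front of reply.Message.
  reply_prefix: "<think>"
```

One detail worth flagging for review: in the streaming branch the prefix is only pushed through `tokenCallback` and is not appended to the accumulated `ss`, so if that branch returns `ss` as the final `Response` (not shown in this hunk), the prefix will appear in the streamed tokens but not in the returned string.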