From 2f0612cb1c168dbecd2d94b9665b11d2f023ffe9 Mon Sep 17 00:00:00 2001 From: Gaurav Garg <52341457+gaugarg-nv@users.noreply.github.com> Date: Thu, 3 Apr 2025 21:50:29 +0530 Subject: [PATCH] CUDA: Prefer vector flash decoding kernel for Gemma models (llama/12738) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prefer vector flash decoding kernel for Gemma models Vector flash decoding kernel was not being picked for models with head dimension 256. Gemma models are in this category. Removing this limit improves e2e performance by upto 12% in gen phase throughput for Gemm models. * Update ggml/src/ggml-cuda/fattn.cu Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8edc1264..7a2d1e45 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128); + const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);