From 1d50c6ac2262ada2d5e75dd1138b8fad3a10db15 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 9 Apr 2025 00:12:57 -0500
Subject: [PATCH] vulkan: Use fp16 for the flash attention P*V multiplication
 (llama/12783)

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
---
 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 8ddadb8a..a8f4bc41 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -330,9 +330,11 @@ void main() {
                 // resize eM by using smear/reduce
                 coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-                O = eMdiag * O;
+                // multiply with fp16 accumulation, then add to O.
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+                PV = coopMatMulAdd(P_A, V, PV);
 
-                O = coopMatMulAdd(P_A, V, O);
+                O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
             }
 
             // If there is split_k, then the split_k resolve shader does the final
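
Note (not part of the patch): the hunk restructures the online-softmax accumulator update. Previously O was rescaled by eMdiag and P*V was accumulated directly into the fp32 accumulator; now P*V is accumulated separately in fp16 and only then converted and added to the rescaled O, matching ggml-cuda and the mul_mat fallback. The following is a minimal scalar sketch of that rework, not the shader code; it assumes the _Float16 extension available in recent Clang/GCC, used only to show where the reduced-precision accumulation happens.

```cpp
#include <cstdio>

// Stand-in for fp16; _Float16 is a Clang/GCC extension, assumed here for illustration.
using half = _Float16;

int main() {
    // Toy 1x1 "matrices": P (softmax probability), V (value),
    // O (running fp32 accumulator), eMdiag (rescale factor from the updated row max).
    float P = 0.25f, V = 1.5f, O = 2.0f, eMdiag = 0.5f;

    // Before the patch: rescale O, then accumulate P*V directly in fp32.
    float O_before = eMdiag * O;
    O_before += P * V;

    // After the patch: accumulate P*V in fp16 first, then rescale O and
    // add the converted partial product.
    half PV = half(0);
    PV = PV + half(P) * half(V);            // fp16 multiply-accumulate
    float O_after = eMdiag * O + float(PV); // convert back and combine

    printf("before = %f, after = %f\n", O_before, O_after);
    return 0;
}
```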