vulkan: Use fp16 for the flash attention P*V multiplication (llama/12783)

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
Jeff Bolz, 2025-04-09 00:12:57 -05:00 (committed by Georgi Gerganov)
parent 79f23d9132
commit 1d50c6ac22

@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
-
-        O = coopMatMulAdd(P_A, V, O);
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
+
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
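
For context, the same idea in scalar form: below is a hypothetical CUDA sketch (the kernel name pv_fp16_acc and the parameter names eM, Br, Bc, D are illustrative assumptions, not the actual ggml-cuda or Vulkan code). It shows fp16 accumulation of one P*V tile followed by a single widening conversion when folding the result into the fp32 running accumulator O, which is what the shader change above does with cooperative matrices.

// Hypothetical scalar sketch (assumed names; not the actual ggml kernels):
// accumulate one tile of the P*V product in fp16, then widen once and add
// it into the fp32 running accumulator O after rescaling by eM.
#include <cuda_fp16.h>

__global__ void pv_fp16_acc(const __half *P,   // Br x Bc post-softmax tile
                            const __half *V,   // Bc x D value tile
                            float *O,          // Br x D running accumulator
                            const float *eM,   // per-row rescale factors
                            int Br, int Bc, int D)
{
    const int r = blockIdx.y * blockDim.y + threadIdx.y; // row in O
    const int c = blockIdx.x * blockDim.x + threadIdx.x; // col in O
    if (r >= Br || c >= D) return;

    // fp16 accumulation of one element of P*V (the shader's coopMatMulAdd
    // does this per cooperative-matrix tile on the matrix hardware)
    __half pv = __float2half(0.0f);
    for (int k = 0; k < Bc; ++k) {
        pv = __hfma(P[r * Bc + k], V[k * D + c], pv);
    }

    // widen the fp16 result once and fold it into the fp32 accumulator,
    // mirroring O = eMdiag * O + coopmat<ACC_TYPE, ...>(PV)
    O[r * D + c] = eM[r] * O[r * D + c] + __half2float(pv);
}

Since P holds post-softmax weights in [0, 1], accumulating a single Bc-wide tile in fp16 typically costs little precision, while the cross-tile running sum in O stays in ACC_TYPE.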