vulkan: fix NaN issue in flash attention shader (llama/12776)

Use -FLT_MAX/2 rather than -inf as the initial value for computing the maximum.
This commit is contained in:
Jeff Bolz 2025-04-06 04:03:47 -05:00 committed by Georgi Gerganov
parent d792d2a2dc
commit 3c26dd3353

View File

@ -227,8 +227,11 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0); M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
@ -278,7 +281,7 @@ void main() {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
} }
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;