diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8d6e99e6..9f4147e9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext( { float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { float S = { 0.0f }; - float M = { -__FLT16_MAX__/2 }; + float M = { -__FLT_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT16_MAX__/2; + float M = -__FLT_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%NL;