diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 1c0ca5ad..80d0765b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]*Q; - const short DK4 = DK/4; - const short DK8 = DK/8; - const short DK16 = DK/16; - const short DV4 = DV/4; - const short DV8 = DV/8; - const short DV16 = DV/16; - const short NW = N_SIMDWIDTH; - const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short DK4 = DK/4; + constexpr short DK8 = DK/8; + constexpr short DK16 = DK/16; + constexpr short DV4 = DV/4; + constexpr short DV8 = DV/8; + constexpr short DV16 = DV/16; + + constexpr short NW = N_SIMDWIDTH; + constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = DK + 2*TS; // shared memory size per query in (half) @@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]; - const short DK4 = DK/4; - const short DV4 = DV/4; - const short NW = N_SIMDWIDTH; - const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads - const short SH = 2*C; // shared memory per simdgroup + constexpr short DK4 = DK/4; + constexpr short DV4 = DV/4; + constexpr short NW = N_SIMDWIDTH; + constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + constexpr short SH = 2*C; // shared memory per simdgroup const short T = DK + nsg*SH; // shared memory size per query in (half) @@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec( half, half4, \ half4 -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16)