metal : use constexpr in FA kernels + fix typedef (llama/12659)

* metal : use constexpr in FA kernels

ggml-ci

* cont

ggml-ci

* cont : fix typedef

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-03-30 22:04:04 +03:00
parent f9015b585b
commit 93631b2be6

View File

@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]*Q; const int iq1 = tgpig[0]*Q;
const short DK4 = DK/4; constexpr short DK4 = DK/4;
const short DK8 = DK/8; constexpr short DK8 = DK/8;
const short DK16 = DK/16; constexpr short DK16 = DK/16;
const short DV4 = DV/4; constexpr short DV4 = DV/4;
const short DV8 = DV/8; constexpr short DV8 = DV/8;
const short DV16 = DV/16; constexpr short DV16 = DV/16;
const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) constexpr short NW = N_SIMDWIDTH;
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = DK + 2*TS; // shared memory size per query in (half) const short T = DK + 2*TS; // shared memory size per query in (half)
@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]; const int iq1 = tgpig[0];
const short DK4 = DK/4; constexpr short DK4 = DK/4;
const short DV4 = DV/4; constexpr short DV4 = DV/4;
const short NW = N_SIMDWIDTH; constexpr short NW = N_SIMDWIDTH;
const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
const short SH = 2*C; // shared memory per simdgroup constexpr short SH = 2*C; // shared memory per simdgroup
const short T = DK + nsg*SH; // shared memory size per query in (half) const short T = DK + nsg*SH; // shared memory size per query in (half)
@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
half, half4, \ half, half4, \
half4 half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)