metal : improve FA + improve MoE (llama/12612)

* ggml : FA with different K, V head sizes (CPU)

ggml-ci

* metal : add FA with HS=192

* metal : extend FA to support different K and V head sizes

ggml-ci

* metal : add FA vector kernels for heads K 192 and V 128

ggml-ci

* ggml : restrict op on other backends to equal head sizes

ggml-ci

* metal : optimize FA-vec kernel

ggml-ci

* metal : FA remove mq registers

* metal : improve MoE mul_mat_id condition

ggml-ci

* metal : fix comments + remove unnecessary addition

ggml-ci

* metal : avoid too much shared memory usage with mul_mat_id

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-03-28 20:21:59 +02:00
parent 1b81415963
commit 27533e7f63
8 changed files with 875 additions and 670 deletions

View File

@ -1791,11 +1791,11 @@ extern "C" {
#define GGML_KQ_MASK_PAD 64 #define GGML_KQ_MASK_PAD 64
// q: [n_embd, n_batch, n_head, 1] // q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1] // k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd, n_head, n_batch, 1] !! permuted !! // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext( GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * q, struct ggml_tensor * q,

View File

@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int64_t D = neq0; const int64_t DK = nek0;
const int64_t N = neq1; const int64_t DV = nev0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N); GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous // input tensor rows must be contiguous
@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == D); GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == D); GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == D); GGML_ASSERT(nev0 == DV);
GGML_ASSERT(neq1 == N); GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);
// dst cannot be transposed or permuted // dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value float M = -INFINITY; // maximum KQ value
float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
if (v->type == GGML_TYPE_F16) { if (v->type == GGML_TYPE_F16) {
memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
} else { } else {
memset(VKQ32, 0, D*sizeof(float)); memset(VKQ32, 0, DV*sizeof(float));
} }
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv2 = iq2 / rv2; const int iv2 = iq2 / rv2;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D); q_to_vec_dot(pq, Q_q, DK);
// online softmax / attention // online softmax / attention
// loop over n_kv and n_head_kv // loop over n_kv and n_head_kv
@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float s; // KQ value float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
s = s*scale; // scale KQ value s = s*scale; // scale KQ value
@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ms = expf(Mold - M); ms = expf(Mold - M);
// V = V*expf(Mold - M) // V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms); ggml_vec_scale_f16(DV, VKQ16, ms);
} else { } else {
// no new maximum, ms == 1.0f, vs != 1.0f // no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M); vs = expf(s - M);
} }
// V += v*expf(s - M) // V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else { } else {
if (s > M) { if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ms = expf(Mold - M); ms = expf(Mold - M);
// V = V*expf(Mold - M) // V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms); ggml_vec_scale_f32(DV, VKQ32, ms);
} else { } else {
// no new maximum, ms == 1.0f, vs != 1.0f // no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M); vs = expf(s - M);
} }
v_to_float(v_data, V32, D); v_to_float(v_data, V32, DV);
// V += v*expf(s - M) // V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs); ggml_vec_mad_f32(DV, VKQ32, V32, vs);
} }
S = S*ms + vs; // scale and increment sum with partial sum S = S*ms + vs; // scale and increment sum with partial sum
} }
if (v->type == GGML_TYPE_F16) { if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) { for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
} }
} }
// V /= S // V /= S
const float S_inv = 1.0f/S; const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv); ggml_vec_scale_f32(DV, VKQ32, S_inv);
// dst indices // dst indices
const int i1 = iq1; const int i1 = iq1;
@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0; size_t cur = 0;
if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) { if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
switch (node->op) { switch (node->op) {
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:
@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
} break; } break;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
{ {
const int64_t ne00 = node->src[0]->ne[0]; // D const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
} break; } break;
case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_FLASH_ATTN_BACK:
{ {

View File

@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE #ifndef FLASH_ATTN_AVAILABLE
return false; return false;
#endif // FLASH_ATTN_AVAILABLE #endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[3] != 1) { if (op->src[0]->ne[3] != 1) {
return false; return false;
} }

View File

@ -219,9 +219,12 @@ typedef struct {
int32_t ne11; int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3; int32_t ne_12_3;
uint64_t nb_12_1; uint64_t nb11;
uint64_t nb_12_2; uint64_t nb12;
uint64_t nb_12_3; uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
uint64_t nb31; uint64_t nb31;
int32_t ne1; int32_t ne1;
int32_t ne2; int32_t ne2;

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4> template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src + il)); reg = (type4)(*(src));
} }
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
@ -56,6 +56,11 @@ template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src); reg = (type4x4)(*src);
} }
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif #endif
template <typename type4x4> template <typename type4x4>
@ -3100,7 +3105,8 @@ template<
typename vd4x4_t, // key type in device memory typename vd4x4_t, // key type in device memory
short nl_v, short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short D, // head size short DK, // K head size
short DV, // V head size
short Q = 8, // queries per threadgroup short Q = 8, // queries per threadgroup
short KV = 8, // key/value processed per each simdgroup short KV = 8, // key/value processed per each simdgroup
short C = 32> // cache items per threadgroup short C = 32> // cache items per threadgroup
@ -3122,20 +3128,23 @@ kernel void kernel_flash_attn_ext(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]*Q; const int iq1 = tgpig[0]*Q;
const short D4 = D/4; const short DK4 = DK/4;
const short D8 = D/8; const short DK8 = DK/8;
const short D16 = D/16; const short DK16 = DK/16;
const short DV4 = DV/4;
const short DV8 = DV/8;
const short DV16 = DV/16;
const short NW = N_SIMDWIDTH; const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = D + 2*TS; // shared memory size per query in (half) const short T = DK + 2*TS; // shared memory size per query in (half)
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@ -3144,23 +3153,23 @@ kernel void kernel_flash_attn_ext(
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o8x8_t lo[D8]; o8x8_t lo[DV8];
// load heads from Q to shared memory // load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) { for (short j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < DK4; i += NW) {
if (iq1 + j < args.ne01) { if (iq1 + j < args.ne01) {
sq4[j*D4 + i] = (q4_t) q4[i]; sq4[j*DK4 + i] = (q4_t) q4[i];
} else { } else {
sq4[j*D4 + i] = (q4_t) 0.0f; sq4[j*DK4 + i] = (q4_t) 0.0f;
} }
} }
} }
// zero out lo // zero out lo
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f); lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
} }
@ -3190,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3); const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) {
simdgroup_load(mq[i], sq + i*8, D);
}
const bool has_mask = mask != q; const bool has_mask = mask != q;
half slope = 1.0f; half slope = 1.0f;
@ -3249,20 +3251,22 @@ kernel void kernel_flash_attn_ext(
// this is compile-time check, so it does not have runtime overhead // this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) { if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
#pragma unroll(D8) #pragma unroll(DK8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DK8; ++i) {
k8x8_t mk; k8x8_t mk;
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); q8x8_t mq;
simdgroup_load(mq, sq + i*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
} }
} else { } else {
for (short ii = 0; ii < D16; ii += 4) { for (short ii = 0; ii < DK16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
if (D16%4 == 0) { if (DK16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks // the head is evenly divisible by 4*16 = 64, so no need for bound checks
{ {
k4x4_t tmp; k4x4_t tmp;
@ -3275,15 +3279,18 @@ kernel void kernel_flash_attn_ext(
#pragma unroll(4) #pragma unroll(4)
for (short k = 0; k < 4; ++k) { for (short k = 0; k < 4; ++k) {
k8x8_t mk; k8x8_t mk;
q8x8_t mq;
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
} }
} else { } else {
if (ii + tx < D16) { if (ii + tx < DK16) {
k4x4_t tmp; k4x4_t tmp;
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
sk4x4[4*ty + tx] = tmp; sk4x4[4*ty + tx] = tmp;
@ -3291,14 +3298,17 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) { for (short k = 0; k < 4 && ii + k < DK16; ++k) {
k8x8_t mk; k8x8_t mk;
q8x8_t mq;
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
} }
} }
} }
@ -3350,8 +3360,8 @@ kernel void kernel_flash_attn_ext(
s8x8_t mm; s8x8_t mm;
simdgroup_load(mm, ss + 2*C, TS, 0, false); simdgroup_load(mm, ss + 2*C, TS, 0, false);
#pragma unroll(D8) #pragma unroll(DV8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]); simdgroup_multiply(lo[i], mm, lo[i]);
} }
} }
@ -3364,20 +3374,20 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) { if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
#pragma unroll(D8) #pragma unroll(DV8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
v8x8_t mv; v8x8_t mv;
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
} }
} else { } else {
for (short ii = 0; ii < D16; ii += 4) { for (short ii = 0; ii < DV16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
if (D16%4 == 0) { if (DV16%4 == 0) {
// no need for bound checks // no need for bound checks
{ {
v4x4_t tmp; v4x4_t tmp;
@ -3398,7 +3408,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
} }
} else { } else {
if (ii + tx < D16) { if (ii + tx < DV16) {
v4x4_t tmp; v4x4_t tmp;
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
sv4x4[4*ty + tx] = tmp; sv4x4[4*ty + tx] = tmp;
@ -3406,7 +3416,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) { for (short k = 0; k < 4 && ii + k < DV16; ++k) {
v8x8_t mv; v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@ -3440,8 +3450,8 @@ kernel void kernel_flash_attn_ext(
// each simdgroup stores its output to shared memory, reusing sq // each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) { if (sgitg == sg) {
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, D, 0, false); simdgroup_store(lo[i], so + i*8, DV, 0, false);
} }
} }
@ -3480,11 +3490,11 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms0, ss + 2*C, TS, 0, false); simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
#pragma unroll(D8) #pragma unroll(DV8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
o8x8_t t; o8x8_t t;
simdgroup_load (t, so + i*8, D, 0, false); simdgroup_load (t, so + i*8, DV, 0, false);
simdgroup_multiply(t, ms1, t); simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@ -3495,8 +3505,8 @@ kernel void kernel_flash_attn_ext(
// store result to shared memory (reuse sq) // store result to shared memory (reuse sq)
if (sgitg == 0) { if (sgitg == 0) {
for (short i = 0; i < D8; ++i) { for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, D, 0, false); simdgroup_store(lo[i], so + i*8, DV, 0, false);
} }
} }
@ -3507,8 +3517,8 @@ kernel void kernel_flash_attn_ext(
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
const float S = ss[j*TS + 0]; const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < DV4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
} }
} }
} }
@ -3525,80 +3535,94 @@ kernel void kernel_flash_attn_ext(
float, simdgroup_float8x8, \ float, simdgroup_float8x8, \
half, half4, simdgroup_half8x8 half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>; template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>; template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>; template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>; template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>; template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>; template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>; template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>; template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>; template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>; template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>; template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>; template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>; template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>; template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>; template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>; template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>; template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>; template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>; template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>; template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>; template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>; template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>; template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>; template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>; template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>; template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>; template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>; template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>; template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>; template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>; template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>; template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>; template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>; template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>; template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>; template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>; template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>; template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
#undef FA_TYPES #undef FA_TYPES
template< template<
typename q4_t, // query types in shared memory typename q4_t, // query types in shared memory
typename q4x4_t, typename k4_t, // key types in shared memory
typename k4x4_t, // key types in shared memory typename v4_t, // value types in shared memory
typename v4x4_t, // value types in shared memory typename qk_t, // Q*K types
typename qk_t, // Q*K types typename s_t, // soft-max types
typename s_t, // soft-max types
typename s4_t, typename s4_t,
typename s4x4_t, typename o4_t, // attention accumulation types
typename o4x4_t, // attention accumulation types typename kd4_t, // key type in device memory
typename kd4x4_t, // key type in device memory
short nl_k, short nl_k,
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
typename vd4x4_t, // key type in device memory typename vd4_t, // key type in device memory
short nl_v, short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short D, // head size short DK, // K head size
short Q = 1, // queries per threadgroup short DV, // V head size
short C = 32> // cache items per threadgroup short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec( kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext & args, constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q, device const char * q,
@ -3617,29 +3641,28 @@ kernel void kernel_flash_attn_ext_vec(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]; const int iq1 = tgpig[0];
const short D4 = D/4; const short DK4 = DK/4;
const short D16 = D/16; const short DV4 = DV/4;
const short NW = N_SIMDWIDTH; const short NW = N_SIMDWIDTH;
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
const short SH = 2*C; // shared memory per simdgroup const short SH = 2*C; // shared memory per simdgroup
const short T = D + nsg*SH; // shared memory size per query in (half) const short T = DK + nsg*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory (the O matrix from the paper)
o4x4_t lo[D16/NL]; o4_t lo[DV4/NL];
// load heads from Q to shared memory // load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < DK4; i += NW) {
if (iq1 < args.ne01) { if (iq1 < args.ne01) {
sq4[i] = (q4_t) q4[i]; sq4[i] = (q4_t) q4[i];
} else { } else {
@ -3648,8 +3671,8 @@ kernel void kernel_flash_attn_ext_vec(
} }
// zero out lo // zero out lo
for (short i = 0; i < D16/NL; ++i) { for (short i = 0; i < DV4/NL; ++i) {
lo[i] = (o4x4_t) 0.0f; lo[i] = (o4_t) 0.0f;
} }
// zero out shared memory SH // zero out shared memory SH
@ -3674,14 +3697,6 @@ kernel void kernel_flash_attn_ext_vec(
const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3); const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q4x4_t mq[D16/NL];
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
mq[ii/NL] = sq4x4[ii + tx];
}
const bool has_mask = mask != q; const bool has_mask = mask != q;
// pointer to the mask // pointer to the mask
@ -3713,43 +3728,56 @@ kernel void kernel_flash_attn_ext_vec(
// Q*K^T // Q*K^T
{ {
// each simdgroup processes 1 query and 4 (NW/NL) keys // each simdgroup processes 1 query and NE (NW/NL) head elements
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/NE; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; qk_t mqk = 0.0f;
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
#pragma unroll(D16/NL) #pragma unroll(DK4/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < DK4; ii += NL) {
const short i = ii + tx; const short i = ii + tx;
k4x4_t mk; k4_t mk;
deq_k(pk + i/nl_k, i%nl_k, mk); deq_k_t4(pk + i/nl_k, i%nl_k, mk);
// note: this is less precise than the version below // note: this is less precise than the version below
//mqka[0] += dot(mq[ii/NL][0], mk[0]); //mqka[0] += dot(mq[0], mk[0]);
//mqka[1] += dot(mq[ii/NL][1], mk[1]); //mqka[1] += dot(mq[1], mk[1]);
//mqka[2] += dot(mq[ii/NL][2], mk[2]); //mqka[2] += dot(mq[2], mk[2]);
//mqka[3] += dot(mq[ii/NL][3], mk[3]); //mqka[3] += dot(mq[3], mk[3]);
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]); //q4x4_t mq = sq4x4[i];
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]); //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]); //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]); //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
//mqka[3] += dot((float4) mq[3], (float4) mk[3]);
mqk += dot((float4) mk, (float4) sq4[i]);
} }
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
// simdgroup reduce // simdgroup reduce (NE = 4)
// [ 0 .. 7] -> [ 0] // [ 0 .. 7] -> [ 0]
// [ 8 .. 15] -> [ 8] // [ 8 .. 15] -> [ 8]
// [16 .. 23] -> [16] // [16 .. 23] -> [16]
// [24 .. 31] -> [24] // [24 .. 31] -> [24]
//mqk += simd_shuffle_down(mqk, 16); if (NE <= 1) {
//mqk += simd_shuffle_down(mqk, 8); mqk += simd_shuffle_down(mqk, 16);
mqk += simd_shuffle_down(mqk, 4); }
mqk += simd_shuffle_down(mqk, 2); if (NE <= 2) {
mqk += simd_shuffle_down(mqk, 1); mqk += simd_shuffle_down(mqk, 8);
}
if (NE <= 4) {
mqk += simd_shuffle_down(mqk, 4);
}
if (NE <= 8) {
mqk += simd_shuffle_down(mqk, 2);
}
if (NE <= 16) {
mqk += simd_shuffle_down(mqk, 1);
}
// mqk = mqk*scale + mask*slope // mqk = mqk*scale + mask*slope
if (tx == 0) { if (tx == 0) {
@ -3759,9 +3787,9 @@ kernel void kernel_flash_attn_ext_vec(
mqk = args.logit_softcap*precise::tanh(mqk); mqk = args.logit_softcap*precise::tanh(mqk);
} }
mqk += sm[4*cc + ty]*slope; mqk += sm[NE*cc + ty]*slope;
ss[4*cc + ty] = mqk; ss[NE*cc + ty] = mqk;
} }
} }
} }
@ -3784,8 +3812,8 @@ kernel void kernel_flash_attn_ext_vec(
ss[tiisg] = vs; ss[tiisg] = vs;
// O = diag(ms)*O // O = diag(ms)*O
#pragma unroll(D16/NL) #pragma unroll(DV4/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < DV4; ii += NL) {
lo[ii/NL] *= ms; lo[ii/NL] *= ms;
} }
} }
@ -3794,17 +3822,18 @@ kernel void kernel_flash_attn_ext_vec(
// O = O + (Q*K^T)*V // O = O + (Q*K^T)*V
{ {
for (short cc = 0; cc < C/4; ++cc) { //#pragma unroll(C/NE)
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); for (short cc = 0; cc < C/NE; ++cc) {
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
const s4x4_t ms(ss[4*cc + ty]); const s4_t ms(ss[NE*cc + ty]);
#pragma unroll(D16/NL) #pragma unroll(DV4/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < DV4; ii += NL) {
const short i = ii + tx; const short i = ii + tx;
v4x4_t mv; v4_t mv;
deq_v(pv4 + i/nl_v, i%nl_v, mv); deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
lo[ii/NL] += mv*ms; lo[ii/NL] += mv*ms;
} }
@ -3819,7 +3848,7 @@ kernel void kernel_flash_attn_ext_vec(
} }
} }
// simdgroup reduce // simdgroup reduce (NE = 4)
// [ 0, 8, 16, 24] -> [ 0] // [ 0, 8, 16, 24] -> [ 0]
// [ 1, 9, 17, 25] -> [ 1] // [ 1, 9, 17, 25] -> [ 1]
// [ 2, 10, 18, 26] -> [ 2] // [ 2, 10, 18, 26] -> [ 2]
@ -3828,37 +3857,48 @@ kernel void kernel_flash_attn_ext_vec(
// [ 5, 13, 21, 29] -> [ 5] // [ 5, 13, 21, 29] -> [ 5]
// [ 6, 14, 22, 30] -> [ 6] // [ 6, 14, 22, 30] -> [ 6]
// [ 7, 15, 23, 31] -> [ 7] // [ 7, 15, 23, 31] -> [ 7]
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < DV4; ii += NL) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); if (NE > 1) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
}
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); if (NE > 2) {
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
}
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); if (NE > 4) {
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
}
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); if (NE > 8) {
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
}
if (NE > 16) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
}
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// store results to shared memory // store results to shared memory
for (short i = tiisg; i < D16; i += NL) { for (short i = tiisg; i < DV4; i += NL) {
sr4x4[i] = lo[i/NL]; sr4[i] = lo[i/NL];
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3885,22 +3925,22 @@ kernel void kernel_flash_attn_ext_vec(
} }
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short i = tiisg; i < D16; i += NW) { for (short i = tiisg; i < DV4; i += NW) {
sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1; sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
} }
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
device float4x4 * dst44 = (device float4x4 *) dst; device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory // final rescale with 1/S and store to global memory
if (sgitg == 0) { if (sgitg == 0) {
const float S = ss[0]; const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) { for (short i = tiisg; i < DV4; i += NW) {
dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
} }
} }
} }
@ -3909,34 +3949,54 @@ kernel void kernel_flash_attn_ext_vec(
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
// //
#define FA_TYPES \ #define FA_TYPES \
half4, half4x4, \ half4, \
half4x4, \ half4, \
half4x4, \ half4, \
float, \ float, \
half, half4, half4x4, \ half, half4, \
half4x4 half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>; template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>; template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>; template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>; template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>; template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>; template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>; template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>; template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>; template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
#undef FA_TYPES #undef FA_TYPES

View File

@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
default: default:
return false; return false;
} }
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) { if (op->src[0]->type != GGML_TYPE_F32) {
return false; return false;
} }

View File

@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
} }
// permute(0, 2, 1, 3) // permute(0, 2, 1, 3)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, logit_softcap }; float params[] = { scale, max_bias, logit_softcap };