metal : add quantized FA support (llama/10149)

* metal : add quantized FA (vec) support

ggml-ci

* metal : add quantized FA (non-vec) support

* metal : fix support check

ggml-ci

* metal : clean-up

* metal : clean-up (cont)

* metal : fix shared memory calc + reduce smem + comments

* metal : float-correctness

* metal : minor [no ci]
This commit is contained in:
Georgi Gerganov 2024-11-06 10:24:23 +02:00
parent f69c8b6f1b
commit 915bcd2c63
2 changed files with 567 additions and 191 deletions

View File

@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[0]->ne[0] == 256) {
if (op->src[1]->type != op->src[2]->type) {
return false;
}
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == src2->type);
GGML_ASSERT(ggml_are_same_shape (src1, src2));
@ -2869,13 +2944,16 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (src1->type) {
case GGML_TYPE_F16:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@ -2883,12 +2961,137 @@ static void ggml_metal_encode_node(
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q4_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q4_1:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q5_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q5_1:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q8_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} else {
use_vec_kernel = true;
switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
case 128:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 256:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
// half1x4 kernel
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the attention scores (nqptg x ncpsg) as f32
//
// 2*ne00*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
}
nsg /= 2;
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;

View File

@ -2723,46 +2723,10 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
typedef void (flash_attn_ext_f16_t)(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne1,
constant int64_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]);
// ref: https://arxiv.org/pdf/2307.08691.pdf
template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16(
// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
kernel void kernel_flash_attn_ext(
device const char * q,
device const char * k,
device const char * v,
@ -2800,13 +2764,13 @@ kernel void kernel_flash_attn_ext_f16(
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const short iq3 = tgpig[2];
const short iq2 = tgpig[1];
const short iq1 = tgpig[0]*Q;
const int iq3 = tgpig[2];
const int iq2 = tgpig[1];
const int iq1 = tgpig[0]*Q;
const short D4 = D/4;
const short D8 = D/8;
//const short Q8 = Q/8;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)
@ -2818,6 +2782,9 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[D8];
@ -2849,25 +2816,28 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup_barrier(mem_flags::mem_threadgroup);
{
float S[Q] = { [0 ... Q-1] = 0.0h };
float S[Q] = { [0 ... Q-1] = 0.0f };
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
// thread indices inside the simdgroup
const short tx = tiisg%4;
const short ty = tiisg/4;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
// k indices
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// v indices
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
@ -2906,6 +2876,9 @@ kernel void kernel_flash_attn_ext_f16(
for (short cc = 0; cc < C/8; ++cc) {
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
// this is compile-time check, so it does not have runtime overhead
if (is_same<block_q, half4x4>::value) {
// we can read directly from global memory
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (short i = 0; i < D8; ++i) {
@ -2914,6 +2887,49 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
half4x4 tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
} else {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
}
}
}
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
}
@ -2977,16 +2993,61 @@ kernel void kernel_flash_attn_ext_f16(
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
simdgroup_float8x8 ms;
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
if (is_same<block_q, half4x4>::value) {
// we can read directly from global memory
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll
for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mk;
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
simdgroup_half8x8 mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
simdgroup_float8x8 mv;
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
if (D16%4 == 0) {
// no need for bound checks
half4x4 tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
}
}
}
}
}
@ -3003,7 +3064,7 @@ kernel void kernel_flash_attn_ext_f16(
// reduce the warps sequentially
for (short sg = 1; sg < nsg; ++sg) {
float S = { 0.0h };
float S = { 0.0f };
float M = { -FLT_MAX/2 };
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3082,15 +3143,54 @@ kernel void kernel_flash_attn_ext_f16(
}
}
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
// NOTE: can use half instead of float precision for some extra perf
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
kernel void kernel_flash_attn_ext_vec(
device const char * q,
device const char * k,
device const char * v,
@ -3128,36 +3228,27 @@ kernel void kernel_flash_attn_ext_vec_f16(
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const short iq3 = tgpig[2];
const short iq2 = tgpig[1];
const short iq1 = tgpig[0];
const int iq3 = tgpig[2];
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)
const short NW4 = NW/4;
const short SH = C; // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
half4 lo[D4/NW];
float4x4 lo[D16/NW4];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@ -3171,8 +3262,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
// zero out lo
for (short i = tiisg; i < D4; i += NW) {
lo[i/NW] = 0.0h;
for (short i = 0; i < D16/NW4; i += NW4) {
lo[i] = float4x4(0.0f);
}
// zero out shared memory SH
@ -3183,38 +3274,52 @@ kernel void kernel_flash_attn_ext_vec_f16(
threadgroup_barrier(mem_flags::mem_threadgroup);
{
float S = { 0.0h };
float M = { -FLT_MAX/2 };
float S = 0.0f;
float M = -FLT_MAX/2;
// thread indices inside the simdgroup
const short tx = tiisg%8;
const short ty = tiisg/8;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
// k indices
const short ik2 = iq2 / rk2;
const short ik3 = iq3 / rk3;
// v indices
const short iv2 = iq2 / rv2;
const short iv3 = iq3 / rv3;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
// load the queries from shared memory into local memory
float4 mq[D4/NW];
float4x4 mq[D16/NW4];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[ii/NW] = (float4) sq4[i];
for (short ii = 0; ii < D16; ii += NW4) {
mq[ii/NW4] = (float4x4) sq44[ii + tx];
}
// pointer to the mask
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
device const half * mp = (device const half *) (mask + iq1*nb31);
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
@ -3226,47 +3331,54 @@ kernel void kernel_flash_attn_ext_vec_f16(
// Q*K^T
{
#pragma unroll
// each simdgroup processes 1 query and 4 keys
for (short cc = 0; cc < C/4; ++cc) {
float4 mqk = { 0.0h };
float mqk = 0.0;
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tx;
float4x4 mk;
mk[0] = (float4) pk4[i + 0*(nb11/8)];
mk[1] = (float4) pk4[i + 1*(nb11/8)];
mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)];
dequantize_func(pk + i/nl, i%nl, mk);
mqk += (float4) (mq[ii/NW] * mk);
mqk +=
dot(mq[ii/NW4][0], mk[0]) +
dot(mq[ii/NW4][1], mk[1]) +
dot(mq[ii/NW4][2], mk[2]) +
dot(mq[ii/NW4][3], mk[3]);
}
// reduce the results from the threads in the simdgroup
mqk += simd_shuffle_down(mqk, 16);
mqk += simd_shuffle_down(mqk, 8);
// simdgroup reduce
// [ 0 .. 7] -> [ 0]
// [ 8 .. 15] -> [ 8]
// [16 .. 23] -> [16]
// [24 .. 31] -> [24]
//mqk += simd_shuffle_down(mqk, 16);
//mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
if (tx == 0) {
mqk *= scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
ss4[cc] = mqk;
ss[4*cc + ty] = mqk;
}
}
}
simdgroup_barrier(mem_flags::mem_threadgroup);
// online softmax
{
const short p = tiisg;
@ -3286,29 +3398,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
lo[ii/NW] *= ms;
for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW4] *= ms;
}
}
simdgroup_barrier(mem_flags::mem_threadgroup);
// O = O + (Q*K^T)*V
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
const float4x4 lss(ss[4*cc + ty]);
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tx;
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}
float4x4 mv;
dequantize_func(pv4 + i/nl, i%nl, mv);
lo[ii/NW4] += mv*lss;
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
@ -3318,10 +3433,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
}
// simdgroup reduce
// [ 0, 8, 16, 24] -> [ 0]
// [ 1, 9, 17, 25] -> [ 1]
// [ 2, 10, 18, 26] -> [ 2]
// [ 3, 11, 19, 27] -> [ 3]
// [ 4, 12, 20, 28] -> [ 4]
// [ 5, 13, 21, 29] -> [ 5]
// [ 6, 14, 22, 30] -> [ 6]
// [ 7, 15, 23, 31] -> [ 7]
for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
}
// store results to shared memory
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = lo[ii/NW];
for (short i = tiisg; i < D16; i += NW4) {
sr44[i] = lo[i/NW4];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3348,30 +3485,41 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
for (short i = tiisg; i < D16; i += NW) {
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4 * dst4 = (device float4 *) dst;
device float4x4 * dst44 = (device float4x4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
for (short i = tiisg; i < D16; i += NW) {
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
}
}
}
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
template<typename T0, typename T1>
kernel void kernel_cpy(