mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-25 21:39:44 +00:00
graph : make FA compatible with MLA + add initial Metal kernels (llama/12953)
* graph : make mla compatible with FA * metal : add exp FA kernels for DeepSeek models ggml-ci * llama : minor naming updates ggml-ci * ggml : disable FA for DS head sizes * tests : add FA tests for MLA shapes ggml-ci
This commit is contained in:
parent
4e936e2afa
commit
36019c35a3
@ -3237,6 +3237,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek MLA
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
@ -354,6 +354,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
||||
@ -362,6 +363,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
@ -370,6 +372,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
@ -378,6 +381,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
@ -386,6 +390,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
@ -394,6 +399,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
@ -402,6 +408,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
||||
@ -437,6 +444,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_SET_I32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
@ -1018,6 +1032,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
||||
@ -1026,6 +1041,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
||||
@ -1034,6 +1050,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
||||
@ -1042,6 +1059,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
||||
@ -1050,6 +1068,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
||||
@ -1058,6 +1077,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
||||
@ -1066,6 +1086,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
||||
@ -1101,6 +1122,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
@ -1365,6 +1393,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
// TODO: not sure if it is worth adding kernels for this size
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
@ -3857,12 +3890,14 @@ static void ggml_metal_encode_node(
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192)) {
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
@ -3885,6 +3920,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
||||
@ -3907,6 +3944,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
@ -3929,6 +3968,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
@ -3951,6 +3992,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
@ -3973,6 +4016,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
@ -3995,6 +4040,8 @@ static void ggml_metal_encode_node(
|
||||
{
|
||||
if (ne00 == 192 && ne20 == 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
||||
} else if (ne00 == 576 && ne20 == 512) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
@ -4114,12 +4161,36 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 576:
|
||||
{
|
||||
if (ne20 == 512) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
template<typename T>
|
||||
|
@ -9261,6 +9261,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case 112:
|
||||
case 128:
|
||||
case 256:
|
||||
case 575: // DeepSeek MLA
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
|
Loading…
x
Reference in New Issue
Block a user