diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 783a0ff8..0e9b2e81 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows && + if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows && qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 @@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; - if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) { - GGML_ASSERT(workgroups_x == 1); + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. split_k = ctx->device->shader_core_count * 2 / workgroups_y; if (split_k > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index e1baa85f..b926a578 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A // Load the slope matrix, indexed by Q's dimension 2. ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) { - const uint32_t h = iq2 + (r & (p.gqa_ratio - 1)); + const uint32_t h = iq2 + (r % p.gqa_ratio); const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);