From b3f3779c1b071e3cd0f0aa8bb2ea22f69ce62558 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 17 Mar 2025 09:26:18 -0500 Subject: [PATCH] vulkan: Add N/2 and N/4 optimized paths in coopmat2 shader (llama/12312) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++--- .../vulkan-shaders/mul_mm_cm2.comp | 79 ++++++++++++++----- 2 files changed, 72 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index aa7281ac..97398f07 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1597,33 +1597,33 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t l_align, m_align, s_align; if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id - l_warptile = { 256, 128, 256, 64 }; - m_warptile = { 256, 128, 128, 64 }; - s_warptile = { 128, 64, 64, 64 }; + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; l_wg_denoms = {128, 256, 1 }; m_wg_denoms = {128, 128, 1 }; s_wg_denoms = { 64, 64, 1 }; // spec constants and tile sizes for quant matmul (non-Qi_K) - l_warptile_mmq = { 256, 128, 256, 64 }; - m_warptile_mmq = { 256, 128, 128, 64 }; - s_warptile_mmq = { 256, 32, 64, 128 }; + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; l_mmq_wg_denoms = { 128, 256, 1 }; m_mmq_wg_denoms = { 128, 128, 1 }; s_mmq_wg_denoms = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul (Qi_K) - l_warptile_mmq_k = { 256, 64, 128, 64 }; - m_warptile_mmq_k = { 256, 32, 64, 64 }; - s_warptile_mmq_k = { 256, 32, 32, 128 }; + l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; + m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; + s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; l_mmq_wg_denoms_k = { 64, 128, 1 }; m_mmq_wg_denoms_k = { 32, 64, 1 }; s_mmq_wg_denoms_k = { 32, 32, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 64, 16 }; - m_warptile_mmqid = { 256, 128, 64, 16 }; - s_warptile_mmqid = { 256, 128, 64, 16 }; + l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + m_warptile_mmqid = { 256, 128, 64, 16, 0 }; + s_warptile_mmqid = { 256, 128, 64, 16, 0 }; l_mmqid_wg_denoms = { 128, 64, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 5b7a4efe..7649febb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; + layout (push_constant) uniform parameter { uint M; @@ -168,15 +172,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - coopmat sum; - sum = coopmat(0.0); - #ifdef MUL_MAT_ID uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; uint pos_b = 0; #else uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif uint stride_a = p.stride_a / QUANT_K; @@ -197,6 +199,7 @@ void main() { tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); #if QUANT_K > 1 tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); @@ -232,16 +235,54 @@ void main() { tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); uint k_iters = (end_k - start_k + BK - 1) / BK; + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + coopmat mat_a; + coopmat mat_b; - coopmat mat_a; - coopmat mat_b; + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + coopmat mat_d = coopmat(sum); - sum = coopMatMulAdd(mat_a, mat_b, sum); + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; } } else #endif // !defined(MUL_MAT_ID) @@ -254,6 +295,9 @@ void main() { tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + coopmat sum; + sum = coopmat(0.0); + [[dont_unroll]] for (uint block_k = start_k; block_k < end_k; block_k += BK) { @@ -296,19 +340,16 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); } } - } - // Convert from ACC_TYPE to D_TYPE - coopmat mat_d; - mat_d = coopmat(sum); + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); #ifdef MUL_MAT_ID - // Call callback to store each element, remapping row through shared memory - coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); #else - tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); - - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; - coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); #endif + } }