From ae6a9bb9a58a0f02e4bb60393a65bde407c9e346 Mon Sep 17 00:00:00 2001 From: Gaurav Garg <52341457+gaugarg-nv@users.noreply.github.com> Date: Thu, 20 Mar 2025 01:22:06 +0530 Subject: [PATCH] CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (llama/12183) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value. - Prefer vector flash attention kernels over MMA kernel for BS=1 Fixes Issue: #12182 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-common.cuh | 88 +++++++++++++++++++--------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 3 +- ggml/src/ggml-cuda/fattn-tile-f16.cu | 63 ++++++++------------ ggml/src/ggml-cuda/fattn-tile-f32.cu | 63 ++++++++------------ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 73 +++++++++-------------- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 73 +++++++++-------------- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 65 +++++--------------- ggml/src/ggml-cuda/fattn.cu | 18 +++--- ggml/src/ggml-cuda/ggml-cuda.cu | 3 + ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 11 files changed, 194 insertions(+), 257 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 4067fd41..1c2a2a13 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup( *dst = dst_val / rowsum; } -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst) { - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); - __shared__ float2 meta[parallel_blocks]; + extern __shared__ float2 meta[]; if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; } __syncthreads(); float kqmax = meta[0].x; -#pragma unroll for (int l = 1; l < parallel_blocks; ++l) { kqmax = max(kqmax, meta[l].x); } float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; -#pragma unroll for (int l = 0; l < parallel_blocks; ++l) { const float diff = meta[l].x - kqmax; const float KQ_max_scale = expf(diff); const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; } static void on_no_fattn_vec_case(const int D) { @@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) { } } -// parallel_blocks == 0 is stream-k decomposition -template +template void launch_fattn( - ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, - const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V, - const int warp_size = WARP_SIZE + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -748,12 +745,14 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } + int parallel_blocks = 1; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; - if (parallel_blocks == 0) { + if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. const int max_blocks = 2*nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; @@ -769,9 +768,43 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); } else { - blocks_num.x = parallel_blocks*ntiles_x; - blocks_num.y = Q->ne[2]; - blocks_num.z = Q->ne[3]; + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -803,7 +836,7 @@ void launch_fattn( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -815,7 +848,7 @@ void launch_fattn( ); CUDA_CHECK(cudaGetLastError()); - if constexpr (parallel_blocks == 0) { + if (stream_k) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(D, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; @@ -824,13 +857,14 @@ void launch_fattn( <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } - } else if constexpr (parallel_blocks > 1) { + } else if (parallel_blocks > 1) { const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 718ee540..024032f6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel = flash_attn_ext_f16; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); } diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index ef3569fa..77455d8e 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) @@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16( //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; const int stride_KV2 = nb11 / sizeof(half2); - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16( __syncthreads(); - const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) { + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new[ncols/nwarps]; @@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16( const int i0 = i00 + 2*threadIdx.x; half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= __half2half2(kqsum_j); } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val); + const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; + dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val); + dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val); } - if (parallel_blocks != 1 && threadIdx.x == 0) { - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else @@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { @@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); @@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; - constexpr int parallel_blocks = 4; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); } else { constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 32; - constexpr int parallel_blocks = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); } return; } constexpr int cols_per_block = 32; - constexpr int parallel_blocks = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); } else { constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 04b69c83..85fea440 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) @@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32( // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; const int stride_KV2 = nb11 / sizeof(half2); - const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32( __syncthreads(); - const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) { + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new[ncols/nwarps]; @@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32( const int i0 = i00 + 2*threadIdx.x; float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val.x /= kqsum_j; dst_val.y /= kqsum_j; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y; + const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; + dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x; + dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y; } - if (parallel_blocks != 1 && threadIdx.x == 0) { - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else @@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32( #endif // FLASH_ATTN_AVAILABLE } -template +template void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { @@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); @@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; - constexpr int parallel_blocks = 4; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); + launch_fattn_tile_f32_64_128(ctx, dst); } else { constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 32; - constexpr int parallel_blocks = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); + launch_fattn_tile_f32_64_128(ctx, dst); } return; } constexpr int cols_per_block = 32; - constexpr int parallel_blocks = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); + launch_fattn_tile_f32_64_128(ctx, dst); } else { constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); + launch_fattn_tile_f32_64_128(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index b7686c1e..32c52ebe 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) @@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16( constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.y + nb01*ic0; - K += nb12*(blockIdx.y / gqa_ratio); - V += nb22*(blockIdx.y / gqa_ratio); + Q += nb02* blockIdx.z + nb01*ic0; + K += nb12*(blockIdx.z / gqa_ratio); + V += nb22*(blockIdx.z / gqa_ratio); const half * maskh = (const half *) mask + ne11*ic0; - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ[ncols] = {{0.0f, 0.0f}}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, @@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16( kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; + dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template @@ -325,65 +323,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); if (Q->ne[1] == 1) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 2; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 4; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 1; + constexpr int cols_per_block = 8; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index c1d2dd8d..336c136d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) @@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32( constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.y + nb01*ic0; - K += nb12*(blockIdx.y / gqa_ratio); - V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape + Q += nb02* blockIdx.z + nb01*ic0; + K += nb12*(blockIdx.z / gqa_ratio); + V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; @@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new_arr[ncols]; @@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32( kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); float dst_val = VKQ[j_VKQ]; - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; + dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template @@ -307,65 +305,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); if (Q->ne[1] == 1) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 2; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; + constexpr int cols_per_block = 4; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 1; + constexpr int cols_per_block = 8; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index dab1d5cb..5c214ea3 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -18,7 +18,7 @@ namespace wmma = rocwmma; #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +template __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16( constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on. static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); @@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16( constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); @@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16( if (ic0 + j_VKQ >= ne01) { return; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; float KQ_rowsum_j; if (std::is_same::value) { @@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16( break; } float dst_val = VKQ[j_VKQ*D_padded + i]; - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val; } - if (parallel_blocks == 1 || threadIdx.x != 0) { + if (gridDim.y == 1 || threadIdx.x != 0) { continue; } @@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16( dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } dst_meta_val.y = KQ_rowsum_j; - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; } #else NO_DEVICE_CODE; @@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; - const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (4*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size); - return; - } - if (2*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size); - return; - } - constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } else { constexpr bool use_logit_softcap = true; fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 2e72fc8f..97354189 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (!fp16_mma_available(cc)) { if (prec == GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); } } else { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); @@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - const int gqa_ratio = Q->ne[2] / K->ne[2]; - const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 && - K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask; - if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) { + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations + const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128); + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } else if(Q->ne[0] <= 128) { + } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - return; } + return; } // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5cb56df9..b783310e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3230,6 +3230,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #ifndef FLASH_ATTN_AVAILABLE return false; #endif // FLASH_ATTN_AVAILABLE + if (op->src[0]->ne[3] != 1) { + return false; + } if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { return false; } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index aace21e3..a4c717a3 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -129,6 +129,7 @@ #define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 997f6714..f2d55796 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -134,5 +134,6 @@ #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamBeginCapture musaStreamBeginCapture #define cudaStreamEndCapture musaStreamEndCapture +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor typedef mt_bfloat16 nv_bfloat16;