diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index 51aa41e7..ecb65999 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co } else #endif // CUDART_VERSION >= 11040 { - asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;" + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" : : "r"(dst), "l"(src)); } #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index fefbd319..7b9566fb 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional. -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ - -template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +template // D == head size __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { - const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - - const int iter_k = ne11 / KQ_stride; - const int iter_j = (ne01 + (ncols - 1)) / ncols; + constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup( const int channel = kbc0 / (iter_k*iter_j); const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; - dst += jt*ncols*ne02*D + channel*D; + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: - float dst_val[ncols] = {0.0f}; - float max_val[ncols] = {0.0f}; - float rowsum[ncols] = {0.0f}; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - break; - } - dst_val[j] = dst[j*ne02*D + threadIdx.x]; + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; - const float2 tmp = dst_fixup[bidx0*ncols + j]; - max_val[j] = tmp.x; - rowsum[j] = tmp.y; + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; } // Iterate over previous blocks and compute the combined results. 
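Note on the rewritten fixup kernel above: each partial result carries an un-normalized output value plus the running softmax max and rowsum, and merging two partials is a standard online-softmax rescale. A minimal host-side sketch of that merge; struct and function names are illustrative and the SOFTMAX_FTZ_THRESHOLD value is assumed to match the constant the kernel uses:

```cpp
#include <cmath>
#include <algorithm>

static constexpr float SOFTMAX_FTZ_THRESHOLD = -20.0f; // assumed to match fattn-common.cuh

struct Partial {
    float val;    // un-normalized output element (numerator)
    float max;    // running KQ maximum for this row
    float rowsum; // running softmax denominator
};

// Merge partial b into partial a, mirroring the loop over previous blocks in
// flash_attn_stream_k_fixup: both accumulators are rescaled by exp(old_max - new_max),
// with a flush-to-zero threshold to drop negligibly small contributions.
static Partial merge_partial(Partial a, const Partial & b) {
    const float max_new = std::max(a.max, b.max);

    const float diff_a  = a.max - max_new;
    const float diff_b  = b.max - max_new;
    const float scale_a = diff_a >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_a) : 0.0f;
    const float scale_b = diff_b >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_b) : 0.0f;

    a.val    = scale_a*a.val    + scale_b*b.val;
    a.rowsum = scale_a*a.rowsum + scale_b*b.rowsum;
    a.max    = max_new;
    return a;
}
// The final output is a.val / a.rowsum, matching "*dst = dst_val / rowsum;" in the kernel.
```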
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; continue; } -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - break; - } - const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; - const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; - // Scale the current and new value accumulators depending on the max. values. - const float max_val_new = fmaxf(max_val[j], tmp.x); + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); - const float diff_val = max_val[j] - max_val_new; - const float diff_add = tmp.x - max_val_new; + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; - const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; - const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; - dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; - rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; - max_val[j] = max_val_new; - } + max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { @@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup( } // Write back final result: -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - return; - } - dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; - } + *dst = dst_val / rowsum; } -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) { } // parallel_blocks == 0 is stream-k decomposition -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V ) { + constexpr int ncols = ncols1 * ncols2; + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -763,25 +747,26 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } - const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); - const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(WARP_SIZE, nwarps, 1); dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. 
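The grid-sizing change in the next hunk can be summarized as a small host-side helper (a sketch; the struct and function names are illustrative): stream-k with 2 blocks per SM is used whenever the device is Ada Lovelace or newer, or when classic data-parallel tiling would leave the last wave badly underfilled.

```cpp
struct GridChoice {
    bool use_stream_k;
    int  nblocks_x; // value assigned to blocks_num.x
};

// Mirrors the heuristic below: max_blocks = 2*nsm, round the tile count up to whole
// waves and check how full those waves would actually be.
static GridChoice choose_grid(int ntiles_total, int nsm, bool cc_is_ada_or_newer) {
    const int max_blocks               = 2*nsm;
    const int tiles_nwaves             = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100*ntiles_total / (max_blocks*tiles_nwaves);

    const bool use_stream_k = cc_is_ada_or_newer || tiles_efficiency_percent < 75;
    return { use_stream_k, use_stream_k ? max_blocks : ntiles_total };
}
```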
- const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm); - const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = 2*nsm; + const int nblocks_stream_k = max_blocks; - const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE; + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75; blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); } else { blocks_num.x = parallel_blocks*ntiles_x; blocks_num.y = Q->ne[2]; @@ -793,7 +778,6 @@ void launch_fattn( } } - float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; @@ -832,9 +816,9 @@ void launch_fattn( if constexpr (parallel_blocks == 0) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine = blocks_num; + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index d777f541..b2e0db9a 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,12 +5,15 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. @@ -27,7 +30,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; constexpr int stride_i = WARP_SIZE / chunks_per_row; #pragma unroll - for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; @@ -40,7 +43,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // If D is not a power of 2, the rest is loaded synchronously. // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); + static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { const int k0_start = stride_k == WARP_SIZE ? 
k0_sync_start : D/2 - (D/2) % (2*stride_k); @@ -52,7 +55,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } #pragma unroll - for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); #pragma unroll @@ -65,12 +68,54 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } -template +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); +#ifdef CP_ASYNC_AVAILABLE + constexpr int preload = KQ_per_iter * sizeof(half); + constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + + cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } +#else + constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + + tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +#endif // CP_ASYNC_AVAILABLE +} + +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, + const half2 * const __restrict__ mask_h2, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -78,42 +123,60 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_Q, const int stride_KV, const int stride_mask, const int jt, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, const tile_B * const __restrict__ Q_B, tile_C_VKQ * const __restrict__ VKQ_C, - float2 & KQ_max, - float2 & KQ_rowsum, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE - constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. 
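For reference, the load_tile/load_mask helpers above rely on the 16-byte asynchronous copy whose PTX string is fixed in the cp-async.cuh hunk. A simplified sketch of what that wrapper does on Ampere or newer (sm_80+); the function name here is illustrative, the real implementation is cp_async_cg_16():

```cpp
// Copy 16 bytes from global to shared memory asynchronously. The shared-memory pointer
// must be converted to a 32-bit shared-state-space address for the PTX instruction.
static __device__ __forceinline__ void copy_16B_async(void * dst_shared, const void * src_global) {
    const unsigned int dst = (unsigned int) __cvta_generic_to_shared(dst_shared);
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
        : : "r"(dst), "l"(src_global));
}
// The data is only guaranteed to have landed after a wait, which the kernels issue as
// cp_async_wait_all() followed by __syncthreads().
```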
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - const int k_VKQ_0 = kb0*KQ_stride; - tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)]; + const int k_VKQ_0 = kb0 * KQ_per_iter; + tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; #ifdef CP_ASYNC_AVAILABLE cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); #else - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); __syncthreads(); #endif // CP_ASYNC_AVAILABLE // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) { + for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { tile_A K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } } } @@ -122,9 +185,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // CP_ASYNC_AVAILABLE if (use_logit_softcap) { - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) { + for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -132,109 +195,209 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (maskh) { - static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size"); - static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size"); + float KQ_max_new[cols_per_thread]; #pragma unroll - for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; -#pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l); + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + 
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + } } } - } - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. - float2 KQ_max_new = KQ_max; - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { #pragma unroll - for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) { - KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); - KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } } - } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads, does not need full warp reduce: #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); - KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); - } + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } - float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { -#pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y; - const float diff = KQ_C[k].x[l] - KQ_max_l; - KQ_C[k].x[l] = expf(diff); + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); - if (l % 2 == 0) { - KQ_rowsum_add.x += KQ_C[k].x[l]; - } else { - KQ_rowsum_add.y += KQ_C[k].x[l]; +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. 
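The max reduction above (ntiles == 1 path) only runs over xor offsets 16, 8 and 4: in the tile<16, 8, float> layout each KQ column is held by 8 lanes spaced 4 apart (fixed threadIdx.x % 4), so those three offsets already reach every owner of a column, while smaller offsets would mix different columns. A standalone sketch of that partial butterfly reduction (the helper name is illustrative):

```cpp
// Reduce a per-column maximum across the 8 lanes that own the same KQ column
// (lane ids c, c+4, c+8, ..., c+28 for a fixed c = threadIdx.x % 4).
static __device__ __forceinline__ float reduce_max_over_column_owners(float v) {
#pragma unroll
    for (int offset = 16; offset >= 4; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset, 32));
    }
    return v;
}
```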
+ static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } } } } { - const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); - const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); - KQ_max = KQ_max_new; - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; - KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; - - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); + float KQ_max_scale[cols_per_thread]; #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } } } } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_stride/(np*2*tile_B::J)]; - static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size"); + tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) { - B[k] = get_transposed(get_half2(KQ_C[k])); + for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = 
get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } } #ifdef CP_ASYNC_AVAILABLE + // Preload K tile for next iteration: cp_async_wait_all(); __syncthreads(); if (!last_iter) { - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } #else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); __syncthreads(); #endif // CP_ASYNC_AVAILABLE // Calculate VKQ tile: #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size"); + static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) { + for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } } } @@ -247,12 +410,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, + const half2 * const __restrict__ mask_h2, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -260,7 +423,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const int ne01, const int ne02, - const int stride_Q, + const int stride_Q1, + const int stride_Q2, const int stride_KV, const int stride_mask, const int jt, @@ -269,63 +433,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); - constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. 
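A note on the shared-memory layout used throughout process_tile: rows are stored with a stride of D2_padded = D/2 + 4 half2 values, i.e. 16 bytes of padding per row. The padding both staggers rows across banks and is later reused to stash the per-row (KQ max, KQ rowsum) meta data. A small sketch of the addressing convention (the helper name is illustrative):

```cpp
#include <cuda_fp16.h>

template <int D>
static __device__ __forceinline__ half2 * tile_row(half2 * tile, int row) {
    constexpr int D2_padded = D/2 + 4; // row stride in half2: D/2 data + 4 half2 (16 B) padding
    return tile + row*D2_padded;
}
// Element k of a row is tile_row<D>(tile, row)[k]; the meta data goes into the padding at
// float2 offset D/4 within the row, matching "((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4]".
```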
- // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements: + // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: extern __shared__ half2 tile_K[]; #ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_stride*D2_padded; + half2 * tile_V = tile_K + KQ_per_iter*D2_padded; #else - half2 * tile_V = tile_K; + half2 * tile_V = tile_K; #endif // CP_ASYNC_AVAILABLE + half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; - tile_B Q_B[D/(2*tile_B::J)]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I]; + tile_B Q_B[D/(2*tile_B::J) * ntiles]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; - float2 KQ_rowsum = {0.0f, 0.0f}; - float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } // Temporarily load Q data into tile_K, will be loaded into registers afterwards. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { continue; } - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; - } - #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { - const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - if (jt*ncols + j < ne01) { + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; - tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); } } } @@ -334,128 +513,217 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); { - const int j0 = (threadIdx.y / np) * tile_B::I; + const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + } + } } } __syncthreads(); - // Preload K data for first iteration when using cp_async: + // Preload mask and K data for first iteration when using cp_async: #ifdef CP_ASYNC_AVAILABLE - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); #endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } // With cp_async there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. #ifdef CP_ASYNC_AVAILABLE - if (nwarps*tile_B::I > KQ_stride) { + if (nwarps*cols_per_warp > KQ_per_iter) { __syncthreads(); } #endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8 threads each, does not need full reduce. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE); - KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE); + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } } // Write VKQ accumulators to shared memory in column-major format. 
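The key indexing change in this kernel is that a tile now covers ncols1 token positions times ncols2 heads of one GQA group, flattened into a single column index jc = j*ncols2 + c; all ncols2 heads attend to the same K/V head, which is what lets the loaded K/V tiles be reused across them. A sketch of the packing (type and member names are illustrative):

```cpp
template <int ncols1, int ncols2>
struct QColumn {
    int j; // token position within the tile, 0..ncols1-1
    int c; // head within the GQA group,      0..ncols2-1

    static __host__ __device__ QColumn unpack(int jc) { return { jc / ncols2, jc % ncols2 }; }
    __host__ __device__ int pack() const              { return j*ncols2 + c; }
};
// Q is then addressed as Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k] and the output
// as dstk[((jt*ncols1 + j)*ne02 + c)*(D/2) + k], as in the surrounding code.
```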
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. - const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - tile_K[j_cwd*D2_padded + k] = B.x[l]; + tile_K[jc_cwd*D2_padded + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } } } - const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset - const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta - const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? 
((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } } - __syncthreads(); - - static_assert(np == 1 || np == 2 || np == 4, "bad np"); - if (np == 1) { - // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < tile_B::I) { - float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[j_cwm] = KQ_cmr; - } - if (is_fixup && threadIdx.x < tile_B::I) { - float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[j_cwm] = KQ_cmr; - } - } else if (threadIdx.y % np == 0) { + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2; + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; - float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - KQ_cm = meta_j[0]; + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; } - float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. #pragma unroll - for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. - float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - KQ_crs = KQ_cms*meta_j[1]; + float KQ_cms[nmeta]; // KQ combine max scale per warp. 
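The np-warp combine around this point boils down to: each cooperating warp publishes a per-column (max, rowsum) pair into the row padding, the owners take the maximum over warps, and the rowsums are summed after rescaling by exp(warp_max - combined_max); the same per-warp scale is later applied to that warp's VKQ accumulator before the final division by the combined rowsum. A serial sketch of the per-column math (the kernel does this warp-parallel; the helper name is illustrative):

```cpp
// meta[ip] = (KQ max, KQ rowsum) published by cooperating warp ip, 0 <= ip < np.
static __device__ __forceinline__ float2 combine_warp_meta(const float2 * meta, int np) {
    float combined_max = meta[0].x;
    for (int ip = 1; ip < np; ++ip) {
        combined_max = fmaxf(combined_max, meta[ip].x);
    }

    float combined_rowsum = 0.0f;
    for (int ip = 0; ip < np; ++ip) {
        combined_rowsum += expf(meta[ip].x - combined_max) * meta[ip].y; // scale == KQ_cms[ip]
    }
    return make_float2(combined_max, combined_rowsum);
}
```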
+#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; } #pragma unroll - for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + } } - if (needs_fixup && threadIdx.x < tile_B::I) { + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } } @@ -470,27 +738,32 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { continue; } - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; - } - #pragma unroll - for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { - const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I; + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); - if (!is_fixup && jt*ncols + j_dst >= ne01) { + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { continue; } - const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2; + + const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -498,8 +771,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]); + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); dstk_val.x += dstk_val_add.x*KQ_crs; dstk_val.y += dstk_val_add.y*KQ_crs; } @@ -511,9 +784,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } if (is_fixup) { - dstk_fixup_data[j_dst*(D/2) + k] = dstk_val; + dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; } else { - dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val; + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; } } } @@ -528,10 +801,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +template __launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -579,20 +850,23 @@ static __global__ void flash_attn_ext_f16( return; } - static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); + static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const int stride_Q = nb01 / sizeof(float2); + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); const int stride_KV = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half); + const int stride_mask = nb31 / sizeof(half2); - const int iter_k = ne11 / KQ_stride; - const int iter_j = (ne01 + (ncols - 1)) / ncols; + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). 
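For orientation, the kbc work index used by the stream-k kernels enumerates iter_k * iter_j * (ne02/ncols2) slices, with k-blocks fastest, then the tile index jt, then the channel (a group of ncols2 heads). A sketch of the decomposition used below; the struct name is illustrative and the kb0 member stands in for the remainder from which kb0_start is derived:

```cpp
struct StreamKWorkItem {
    int channel; // which group of ncols2 heads
    int jt;      // which tile of ncols1 Q columns
    int kb0;     // first KQ block of the assigned slice
};

static __host__ __device__ StreamKWorkItem decompose_kbc(int kbc, int iter_k, int iter_j) {
    StreamKWorkItem w;
    w.channel = kbc / (iter_k*iter_j);
    w.jt      = (kbc - w.channel*iter_k*iter_j) / iter_k;
    w.kb0     = kbc % iter_k;
    return w;
}
```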
@@ -605,25 +879,28 @@ static __global__ void flash_attn_ext_f16( const int channel = kbc / (iter_k*iter_j); const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(D/2); + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); - const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -640,39 +917,46 @@ static __global__ void flash_attn_ext_f16( const int channel = kbc / (iter_k*iter_j); const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(D/2); + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); - const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - typedef tile<16, 8, half2> tile_A; - typedef tile< 8, 8, half2> tile_B; + constexpr int ncols = ncols1 * ncols2; + constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; + constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; + constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int cols_per_warp = ntiles * tile_B::I; - static_assert(D % tile_B::J == 0, "bad D"); - static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block"); + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); const ggml_tensor * KQV = dst; - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; - constexpr int KQ_stride = D <= 128 ? 64 : 32; - constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? - cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8); + const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; - const int nrows_KQ = cp_async_available(cc) ? 
2*KQ_stride : KQ_stride; - const int nrows_combine = nwarps*tile_B::J; - const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half); + const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); + const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); + + const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -680,42 +964,58 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \ + +#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -extern DECL_FATTN_MMA_F16_CASE( 64, 8); -extern DECL_FATTN_MMA_F16_CASE( 80, 8); -extern DECL_FATTN_MMA_F16_CASE( 96, 8); -extern DECL_FATTN_MMA_F16_CASE(112, 8); -extern DECL_FATTN_MMA_F16_CASE(128, 8); -extern DECL_FATTN_MMA_F16_CASE(256, 8); +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ -extern DECL_FATTN_MMA_F16_CASE( 64, 16); -extern DECL_FATTN_MMA_F16_CASE( 80, 16); -extern DECL_FATTN_MMA_F16_CASE( 96, 16); -extern DECL_FATTN_MMA_F16_CASE(112, 16); -extern DECL_FATTN_MMA_F16_CASE(128, 16); -extern DECL_FATTN_MMA_F16_CASE(256, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8); -extern DECL_FATTN_MMA_F16_CASE( 64, 32); -extern DECL_FATTN_MMA_F16_CASE( 80, 32); -extern DECL_FATTN_MMA_F16_CASE( 96, 32); -extern DECL_FATTN_MMA_F16_CASE(112, 32); -extern DECL_FATTN_MMA_F16_CASE(128, 32); -extern DECL_FATTN_MMA_F16_CASE(256, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16); -extern DECL_FATTN_MMA_F16_CASE( 64, 64); -extern DECL_FATTN_MMA_F16_CASE( 80, 64); -extern DECL_FATTN_MMA_F16_CASE( 96, 64); -extern DECL_FATTN_MMA_F16_CASE(112, 64); -extern DECL_FATTN_MMA_F16_CASE(128, 64); -extern DECL_FATTN_MMA_F16_CASE(256, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32); + 
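To make the new shared-memory accounting concrete, here is a sketch that mirrors the constexpr selection and the three buffer sizes in ggml_cuda_flash_attn_ext_mma_f16_case (the function name is illustrative; tile_B::I == 8 is assumed). For example, D = 128, ncols1 = 32, ncols2 = 2 with cp.async available gives KQ_per_iter = 64, nwarps = 4, ntiles = 2 and max(34816 + 4608, 17408) = 39424 bytes:

```cpp
#include <cstddef>
#include <cuda_fp16.h>

constexpr size_t mma_f16_nbytes_shared(int D, int ncols1, int ncols2, bool cp_async) {
    const int ncols         = ncols1*ncols2;
    const int KQ_per_iter   = (D <= 128 && ncols1 <= 64) ? 64 : 32;
    const int nwarps        = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
    const int ntiles        = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
    const int cols_per_warp = ntiles*8; // tile_B::I == 8

    const size_t KV      = (cp_async ? 2*KQ_per_iter : KQ_per_iter) * (D + 8) * sizeof(half); // K/V tiles, double-buffered with cp.async
    const size_t mask    = ncols1 * (KQ_per_iter + 8) * sizeof(half);                         // mask tile
    const size_t combine = (size_t)(nwarps*cols_per_warp) * (D + 8) * sizeof(half);           // combine/write-back buffer
    return KV + mask > combine ? KV + mask : combine;
}
```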
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64); + +// Kernels with ncols == 128 are only 4% faster due to register pressure. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d4edbad0..b8b415ef 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 0d274f33..4352a284 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index d9ac4424..e758a0f6 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 6ef8f9dc..134144a3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -290,7 +290,7 @@ void 
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 template
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 45702ad6..de38470a 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
 }
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index b0cf152f..b1becccb 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -8,28 +8,50 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
-template
+template
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    if (Q->ne[1] <= 8/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 16/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 32/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+}
+
+template
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
             break;
        case 112:
-            ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
             break;
        case 128:
-            ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
 }
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
-    if (Q->ne[1] <= 8) {
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    if (use_gqa_opt && gqa_ratio % 8 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
         return;
     }
-    if (Q->ne[1] <= 16) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
-    if (Q->ne[1] <= 32) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 2) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
-    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
 }
 #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
+        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
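The dispatch above maps the grouped-query ratio onto the kernel's second column dimension: when a mask is present and ALiBi is off (max_bias == 0), ncols2 Q heads that share one K/V head are packed into a single tile, and ncols1 = ncols/ncols2 then only has to cover the Q rows. A minimal host-side C++ sketch of the same selection logic follows; the struct and function names are illustrative stand-ins, not code from this patch.

    #include <cstdio>

    // Illustrative stand-in for the tensors inspected by the dispatch code.
    struct FakeTensor { long ne2; };  // ne[2] = number of heads

    // Pick ncols2 the same way the dispatch does: only when a mask is present
    // and ALiBi is disabled (max_bias == 0) can Q heads that share a K/V head
    // be packed into one tile; otherwise fall back to ncols2 == 1.
    static int choose_ncols2(const FakeTensor & Q, const FakeTensor & K, bool has_mask, float max_bias) {
        const bool use_gqa_opt = has_mask && max_bias == 0.0f;
        const int  gqa_ratio   = (int)(Q.ne2 / K.ne2);  // Q.ne2 % K.ne2 == 0 is asserted upstream

        if (use_gqa_opt && gqa_ratio % 8 == 0) return 8;
        if (use_gqa_opt && gqa_ratio == 4)     return 4;
        if (use_gqa_opt && gqa_ratio == 2)     return 2;
        return 1;
    }

    int main() {
        const FakeTensor Q{32}, K{8};  // e.g. 32 Q heads over 8 K/V heads -> gqa_ratio == 4
        printf("ncols2 = %d\n", choose_ncols2(Q, K, /*has_mask=*/true, /*max_bias=*/0.0f));  // prints 4
        return 0;
    }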
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 0a5656e4..9206bfeb 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
             return threadIdx.x / 4;
         } else if constexpr (I == 16 && J == 8) {
             return (l / 2) * 8 + threadIdx.x / 4;
+        } else if constexpr (I == 16 && J == 16) {
+            return ((l / 2) % 2) * 8 + threadIdx.x / 4;
         } else {
             static_assert(I == -1 && J == -1, "template specialization not implemented");
         }
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
             return 4 * l + threadIdx.x % 4;
         } else if constexpr (I == 16 && J == 8) {
             return 2 * (threadIdx.x % 4) + l % 2;
+        } else if constexpr (I == 16 && J == 16) {
+            return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
         } else {
             static_assert(I == -1 && J == -1, "template specialization not implemented");
         }
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
 }
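Each new 16x16 accumulator tile is filled by two m16n8k16 MMAs on Ampere (four m16n8k8 on Turing), which leaves every thread of the warp holding 8 of the 256 output values. The plain C++ sketch below re-implements the new get_i/get_j formulas for tile<16, 16, float> and prints which element each (thread, l) pair owns; threadIdx.x is emulated with an ordinary loop variable, so this is an illustration of the indexing only, not device code from the patch.

    #include <cstdio>

    // Same formulas as the new I == 16 && J == 16 branches in ggml_cuda_mma:
    static int get_i(int thread, int l) { return ((l / 2) % 2) * 8 + thread / 4; }
    static int get_j(int thread, int l) { return 8 * (l / 4) + 2 * (thread % 4) + l % 2; }

    int main() {
        // 32 threads x 8 values per thread == 256 elements, i.e. the full 16x16 tile.
        int owner[16][16];
        for (int thread = 0; thread < 32; ++thread) {
            for (int l = 0; l < 8; ++l) {
                owner[get_i(thread, l)][get_j(thread, l)] = thread;
            }
        }
        for (int i = 0; i < 16; ++i) {
            for (int j = 0; j < 16; ++j) {
                printf("%2d ", owner[i][j]);
            }
            printf("\n");
        }
        return 0;
    }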
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu
deleted file mode 100644
index f09bdeff..00000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 16);
-DECL_FATTN_MMA_F16_CASE(80, 16);
-DECL_FATTN_MMA_F16_CASE(96, 16);
-DECL_FATTN_MMA_F16_CASE(112, 16);
-DECL_FATTN_MMA_F16_CASE(128, 16);
-DECL_FATTN_MMA_F16_CASE(256, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu
deleted file mode 100644
index 22110887..00000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 32);
-DECL_FATTN_MMA_F16_CASE(80, 32);
-DECL_FATTN_MMA_F16_CASE(96, 32);
-DECL_FATTN_MMA_F16_CASE(112, 32);
-DECL_FATTN_MMA_F16_CASE(128, 32);
-DECL_FATTN_MMA_F16_CASE(256, 32);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu
deleted file mode 100644
index d24b0857..00000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 64);
-DECL_FATTN_MMA_F16_CASE(80, 64);
-DECL_FATTN_MMA_F16_CASE(96, 64);
-DECL_FATTN_MMA_F16_CASE(112, 64);
-DECL_FATTN_MMA_F16_CASE(128, 64);
-DECL_FATTN_MMA_F16_CASE(256, 64);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu
deleted file mode 100644
index bdf86c0e..00000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 8);
-DECL_FATTN_MMA_F16_CASE(80, 8);
-DECL_FATTN_MMA_F16_CASE(96, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
new file mode 100644
index 00000000..80108615
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 1, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
new file mode 100644
index 00000000..66161c0a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 16, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
new file mode 100644
index 00000000..ee88c72a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 16, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
new file mode 100644
index 00000000..d888a5a4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
new file mode 100644
index 00000000..d93a2d08
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
new file mode 100644
index 00000000..617464c9
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 2, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
new file mode 100644
index 00000000..970d2b68
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 32, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
new file mode 100644
index 00000000..65cd377c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 32, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
new file mode 100644
index 00000000..f4a8bf34
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 4, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
new file mode 100644
index 00000000..de191a8a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
new file mode 100644
index 00000000..e8cb0e1b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 4, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
new file mode 100644
index 00000000..a532e962
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 64, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
new file mode 100644
index 00000000..bf25181a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 8, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
new file mode 100644
index 00000000..378c132e
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 8, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
new file mode 100644
index 00000000..372641be
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 8, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
new file mode 100644
index 00000000..9ff5968b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index a2628f16..dd373a09 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
 """
-SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
             with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                 f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
-for cols_per_block in [8, 16, 32, 64]:
-    with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
-        f.write(SOURCE_FATTN_MMA_START)
+for ncols in [8, 16, 32, 64, 128]:
+    for ncols2 in [1, 2, 4, 8]:
+        ncols1 = ncols // ncols2
+        if ncols == 128:
+            continue # Too much register pressure.
+        with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
+            f.write(SOURCE_FATTN_MMA_START)
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
+            for head_size in [64, 80, 96, 112, 128, 256]:
+                if ncols == 128 and head_size == 256:
+                    continue # Needs too much shared memory.
+                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: