CUDA: optimize FA for GQA + large batches (llama/12014)

This commit is contained in:
Johannes Gäßler 2025-02-22 12:20:17 +01:00 committed by Georgi Gerganov
parent 98dab49b9a
commit 2d70cd36d7
31 changed files with 920 additions and 396 deletions

View File

@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else

View File

@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, int KQ_stride> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
const int j = blockIdx.y;
const int c = blockIdx.z;
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
dst += jt*ncols*ne02*D + channel*D;
if (jt*ncols1 + j >= ne01) {
return;
}
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val[ncols] = {0.0f};
float max_val[ncols] = {0.0f};
float rowsum[ncols] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
dst_val[j] = dst[j*ne02*D + threadIdx.x];
float dst_val = 0.0f;
float max_val = 0.0f;
float rowsum = 0.0f;
{
dst_val = *dst;
const float2 tmp = dst_fixup[bidx0*ncols + j];
max_val[j] = tmp.x;
rowsum[j] = tmp.y;
const float2 tmp = dst_fixup[bidx0*ncols + jc];
max_val = tmp.x;
rowsum = tmp.y;
}
// Iterate over previous blocks and compute the combined results.
@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val[j], tmp.x);
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val, tmp.x);
const float diff_val = max_val[j] - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float diff_val = max_val - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
dst_val = scale_val*dst_val + scale_add*dst_add;
rowsum = scale_val*rowsum + scale_add*tmp.y;
max_val[j] = max_val_new;
}
max_val = max_val_new;
// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
}
// Write back final result:
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
return;
}
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
}
*dst = dst_val / rowsum;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
}
// parallel_blocks == 0 is stream-k decomposition
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
constexpr int ncols = ncols1 * ncols2;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
@ -763,25 +747,26 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const dim3 block_dim(WARP_SIZE, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
const int max_blocks = 2*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
const int nblocks_stream_k = 2*nsm;
const int nblocks_stream_k = max_blocks;
const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
@ -793,7 +778,6 @@ void launch_fattn(
}
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
@ -832,9 +816,9 @@ void launch_fattn(
if constexpr (parallel_blocks == 0) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}

File diff suppressed because it is too large Load Diff

View File

@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View File

@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View File

@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
constexpr int parallel_blocks = 1;
@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -8,28 +8,50 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
template <int cols_per_block>
template <int D, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
return;
}
if (Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
}
template <int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * mask = dst->src[3];
if (Q->ne[1] <= 8) {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const float use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
return;
}
if (Q->ne[1] <= 16) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
if (use_gqa_opt && gqa_ratio == 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
if (use_gqa_opt && gqa_ratio == 2) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
const int gqa_ratio = Q->ne[2] / K->ne[2];
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
return;

View File

@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
#endif // NEW_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
#endif // NEW_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
}

View File

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16);
DECL_FATTN_MMA_F16_CASE(80, 16);
DECL_FATTN_MMA_F16_CASE(96, 16);
DECL_FATTN_MMA_F16_CASE(112, 16);
DECL_FATTN_MMA_F16_CASE(128, 16);
DECL_FATTN_MMA_F16_CASE(256, 16);

View File

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32);
DECL_FATTN_MMA_F16_CASE(80, 32);
DECL_FATTN_MMA_F16_CASE(96, 32);
DECL_FATTN_MMA_F16_CASE(112, 32);
DECL_FATTN_MMA_F16_CASE(128, 32);
DECL_FATTN_MMA_F16_CASE(256, 32);

View File

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64);
DECL_FATTN_MMA_F16_CASE(80, 64);
DECL_FATTN_MMA_F16_CASE(96, 64);
DECL_FATTN_MMA_F16_CASE(112, 64);
DECL_FATTN_MMA_F16_CASE(128, 64);
DECL_FATTN_MMA_F16_CASE(256, 64);

View File

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8);
DECL_FATTN_MMA_F16_CASE(80, 8);
DECL_FATTN_MMA_F16_CASE(96, 8);
DECL_FATTN_MMA_F16_CASE(112, 8);
DECL_FATTN_MMA_F16_CASE(128, 8);
DECL_FATTN_MMA_F16_CASE(256, 8);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 1, 8);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
DECL_FATTN_MMA_F16_CASE(256, 16, 1);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 16, 2);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 16, 4);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 2, 4);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 2, 8);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
DECL_FATTN_MMA_F16_CASE(256, 32, 1);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 32, 2);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 4, 2);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 4, 4);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 4, 8);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
DECL_FATTN_MMA_F16_CASE(256, 64, 1);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
DECL_FATTN_MMA_F16_CASE(256, 8, 1);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 8, 2);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 8, 4);

View File

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 8, 8);

View File

@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
"""
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
for cols_per_block in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for ncols in [8, 16, 32, 64, 128]:
for ncols2 in [1, 2, 4, 8]:
ncols1 = ncols // ncols2
if ncols == 128:
continue # Too much register pressure.
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for head_size in [64, 80, 96, 112, 128, 256]:
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
for head_size in [64, 80, 96, 112, 128, 256]:
if ncols == 128 and head_size == 256:
continue # Needs too much shared memory.
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: