CUDA: FA support for Deepseek (Ampere or newer) (llama/13306)

* CUDA: FA support for Deepseek (Ampere or newer)

* do loop unrolling via C++ template
This commit is contained in:
Johannes Gäßler 2025-05-09 13:34:58 +02:00 committed by Georgi Gerganov
parent 4b7cbb62ef
commit 2d436bfbfb
32 changed files with 841 additions and 547 deletions

View File

@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
set(CUDA_CXX_FLAGS "")
set(CUDA_FLAGS -use_fast_math)
set(CUDA_FLAGS -use_fast_math -extended-lambda)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
# Options are:

View File

@ -296,6 +296,25 @@ static __device__ void no_device_code(
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
// The compiler is always able to unroll loops if they contain continue expressions.
// In such cases loop unrolling can still be achieved via recursion:
template <int n>
struct ggml_cuda_unroll {
template <typename Func, typename... Args>
__device__ void operator()(const Func & f, Args... args) const {
f(n - 1, args...);
ggml_cuda_unroll<n - 1>{}(f, args...);
}
};
template <>
struct ggml_cuda_unroll<1> {
template <typename Func, typename... Args>
__device__ void operator()(const Func & f, Args... args) const {
f(0, args...);
}
};
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

View File

@ -2,6 +2,17 @@
#include "common.cuh"
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
#ifdef CP_ASYNC_AVAILABLE
return __cvta_generic_to_shared(generic_ptr);
#else
GGML_UNUSED(generic_ptr);
NO_DEVICE_CODE;
return 0;
#endif // CP_ASYNC_AVAILABLE
}
// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.

View File

@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
} else {
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
fprintf(stderr, "Only f16 is supported.\n");
GGML_ABORT("fatal error");
}
}
template <int D, int ncols1, int ncols2, int KQ_stride>
template <int DV, int ncols1, int ncols2>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@ -691,7 +691,7 @@ void launch_fattn(
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
@ -754,10 +754,13 @@ void launch_fattn(
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
dim3 blocks_num;
if (stream_k) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = 2*nsm;
const int max_blocks = max_blocks_per_sm*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
@ -769,14 +772,11 @@ void launch_fattn(
blocks_num.y = 1;
blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
} else {
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
@ -853,19 +853,19 @@ void launch_fattn(
if (stream_k) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(D, 1, 1);
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
flash_attn_combine_results<D>
flash_attn_combine_results<DV>
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
}

File diff suppressed because it is too large Load Diff

View File

@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
case 128: {
@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
default: {

View File

@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
case 128: {
@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
default: {

View File

@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
template <int D, ggml_type type_K, ggml_type type_V>

View File

@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -8,58 +8,32 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
template <int D, int ncols2>
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
return;
if constexpr (ncols2 <= 8) {
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
return;
}
}
if (Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}
template <int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const float use_gqa_opt = mask && max_bias == 0.0f;
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio == 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
if (use_gqa_opt && gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio == 2) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
switch (Q->ne[0]) {
case 64:
GGML_ASSERT(V->ne[0] == 64);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
break;
case 80:
GGML_ASSERT(V->ne[0] == 80);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
break;
case 96:
GGML_ASSERT(V->ne[0] == 96);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
break;
case 112:
GGML_ASSERT(V->ne[0] == 112);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
break;
case 128:
GGML_ASSERT(V->ne[0] == 128);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(use_gqa_opt);
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} break;
default:
GGML_ABORT("fatal error");
break;
}
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);

View File

@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);

View File

@ -2,9 +2,9 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);

View File

@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
"""
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
for ncols in [8, 16, 32, 64, 128]:
for ncols2 in [1, 2, 4, 8]:
for ncols in [8, 16, 32, 64]:
for ncols2 in [1, 2, 4, 8, 16]:
if ncols2 > ncols:
continue
ncols1 = ncols // ncols2
if ncols == 128:
continue # Too much register pressure.
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for head_size in [64, 80, 96, 112, 128, 256]:
if ncols == 128 and head_size == 256:
continue # Needs too much shared memory.
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: