mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-03-01 03:26:12 +00:00)

commit 51a3580c79, parent 37a21dd43d

CUDA: use async data loading for FlashAttention (llama/11894)

Co-authored-by: Diego Devesa <slarengh@gmail.com>
@@ -41,12 +41,13 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

 #define GGML_CUDA_CC_PASCAL       600
 #define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define GGML_CUDA_CC_VOLTA        700
 #define GGML_CUDA_CC_TURING       750
 #define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
 #define GGML_CUDA_CC_OFFSET_AMD   0x1000000

 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4         (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }

+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return __AMDGCN_WAVEFRONT_SIZE;
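
Note that CP_ASYNC_AVAILABLE gates the device code at compile time per architecture, while cp_async_available(cc) lets the host make the matching decision at run time (it is used further down to size the shared-memory buffers). A minimal sketch of how the two sides pair up, with a hypothetical helper name that is not part of the diff:

    // Hypothetical host-side helper: kernels compiled for cc >= GGML_CUDA_CC_AMPERE
    // see CP_ASYNC_AVAILABLE defined and take the asynchronous load path.
    static int fattn_kv_smem_rows_sketch(const int cc, const int KQ_stride) {
        // Double-buffer K and V when cp.async can overlap the loads with the math:
        return cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
    }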
ggml/src/ggml-cuda/cp-async.cuh (new file, 46 lines)
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 byte copy is exposed because the 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
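
For illustration, a minimal sketch of how this API is meant to be used (not part of the diff; it assumes CP_ASYNC_AVAILABLE is defined, that src is 16-byte aligned, and that the usual CUDA fp16 types are available via common.cuh):

    // Toy kernel: copy `rows` x 8 half2 values from global to shared memory,
    // one 16-byte cp.async (4 half2) per call, then wait before reading the tile.
    template <int rows>
    static __global__ void copy_tile_example(const half2 * __restrict__ src) {
        __shared__ __align__(16) half2 tile[rows][8];

        const unsigned int dst32 = __cvta_generic_to_shared(tile);
        for (int i = threadIdx.x; i < rows; i += blockDim.x) {
            cp_async_cg_16<0>(dst32 + (i*8 + 0)*sizeof(half2), src + i*8 + 0);
            cp_async_cg_16<0>(dst32 + (i*8 + 4)*sizeof(half2), src + i*8 + 4);
        }
        cp_async_wait_all(); // waits only for this thread's own copies ...
        __syncthreads();     // ... the barrier makes the whole tile visible to all threads.
        // ... compute on tile ...
    }

cp_async_cg_16 takes the shared-memory destination as a 32-bit address and the global source as a generic pointer, matching the "r" and "l" asm constraints above.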
@@ -716,7 +716,9 @@ void launch_fattn(

     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;

     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int  tiles_nwaves      = (ntiles_total - nsm - 1) / nsm;
-        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
-        const bool short_context     = K->ne[1] < 4096;
+        const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);
+        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);

         const int nblocks_stream_k = 2*nsm;

-        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;

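
As a worked example of the new heuristic (assumed numbers, not from the diff): with nsm = 108 (an A100) and ntiles_total = 300,

    tiles_nwaves             = (300 + 2*108 - 1) / (2*108) = 2
    tiles_efficiency_percent = 100 * 300 / (2*108*2)       = 69

Since 69 < 75, use_stream_k is true and the grid is launched with nblocks_stream_k = 2*nsm = 216 blocks; on Ada Lovelace or newer the stream-k layout is chosen regardless of the estimate.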
@@ -827,7 +830,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());

     if constexpr (parallel_blocks == 0) {
-        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = blocks_num;

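
Continuing the same assumed numbers: with ntiles_total = 300 and blocks_num.x = 216, 300 % 216 = 84 != 0, so some blocks finish on a fractional tile and the combine (fixup) kernel below has to run. The fixup can only be skipped when the tile count is an exact multiple of the block count, for example when the grid was launched with blocks_num.x == ntiles_total in the non-stream-k case.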
@@ -1,7 +1,252 @@
 #include "common.cuh"
+#include "cp-async.cuh"
 #include "mma.cuh"
 #include "fattn-common.cuh"

+using namespace ggml_cuda_mma;
+
+typedef tile<16, 8, half2> tile_A;
+typedef tile< 8, 8, half2> tile_B;
+typedef tile<16, 8, float> tile_C_KQ;
+typedef tile<16, 4, half2> tile_C_VKQ;
+
+template<int D, int nwarps, int KQ_stride>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
+#ifdef CP_ASYNC_AVAILABLE
+    static_assert(D >= 64 && D < 512, "bad D");
+    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
+
+    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
+
+    constexpr int preload        = 64;
+    constexpr int h2_per_chunk   = 16/sizeof(half2);
+    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
+    constexpr int stride_i       = WARP_SIZE / chunks_per_row;
+#pragma unroll
+    for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
+        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
+
+        cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
+    }
+#else
+    constexpr int k0_sync_start = 0;
+#endif // CP_ASYNC_AVAILABLE
+    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
+
+    // If D is not a power of 2, the rest is loaded synchronously.
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
+        const int stride_i = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
+            continue;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+#pragma unroll
+            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
+            }
+        }
+    }
+}
+
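A spot-check of the constants in flash_attn_ext_f16_load_tile for one assumed configuration, D = 128 with 32-thread warps (a sketch, not part of the diff):

    // sizeof(half2) == 4 bytes
    constexpr int D              = 128;
    constexpr int D2             = D/2;                                  // 64 half2 per K/V row
    constexpr int k0_sync_start  = D2 < 64 ? 32 : (D2 < 128 ? 64 : 128); // == 64
    constexpr int h2_per_chunk   = 16/4;                                 // 4 half2 per 16-byte cp.async
    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;         // 16 threads cover one row
    constexpr int stride_i       = 32 / chunks_per_row;                  // each warp loads 2 rows per step
    static_assert(k0_sync_start == D2, "for D == 128 the whole row goes through cp.async");

Because k0_sync_start already equals D/2 here, the synchronous tail loop finds k0_start == k0_stop and is skipped; it only does work for head sizes where D/2 is not exactly 32, 64, or 128 (for example D = 80, 96, or 112).
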
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half   * const __restrict__ maskh,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        half2        * const __restrict__ tile_K,
+        half2        * const __restrict__ tile_V,
+        const tile_B * const __restrict__ Q_B,
+        tile_C_VKQ   * const __restrict__ VKQ_C,
+        float2       & KQ_max,
+        float2       & KQ_rowsum,
+        const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+    constexpr int np        = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    const int k_VKQ_0 = kb0*KQ_stride;
+    tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
+
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
+    __syncthreads();
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+#else
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate tile of KQ:
+#pragma unroll
+    for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
+        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
+            tile_A K_A;
+            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+            mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+    if (use_logit_softcap) {
+        static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+            }
+        }
+    }
+
+    if (maskh) {
+        static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
+#pragma unroll
+        for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
+            const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                const int i = i0 + tile_C_KQ::get_i(l);
+                const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
+
+                KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
+            }
+        }
+    }
+
+    // Calculate softmax for each KQ column using the current max. value.
+    // The divisor is stored in KQ_rowsum and will be applied at the end.
+    float2 KQ_max_new = KQ_max;
+    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+        for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
+            KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
+            KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
+        }
+    }
+
+    // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+    for (int offset = 16; offset > 2; offset >>= 1) {
+        KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
+        KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
+    }
+
+    float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
+    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+        for (int l = 0; l < tile_C_KQ::ne; ++l) {
+            const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
+            const float diff = KQ_C[k].x[l] - KQ_max_l;
+            KQ_C[k].x[l] = expf(diff);
+
+            if (l % 2 == 0) {
+                KQ_rowsum_add.x += KQ_C[k].x[l];
+            } else {
+                KQ_rowsum_add.y += KQ_C[k].x[l];
+            }
+        }
+    }
+
+    {
+        const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
+        const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
+        KQ_max = KQ_max_new;
+
+        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
+        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
+
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
+#pragma unroll
+        for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
+#pragma unroll
+            for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+    }
+
+    // Convert KQ C tiles into B tiles for VKQ calculation:
+    tile_B B[KQ_stride/(np*2*tile_B::J)];
+    static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
+#pragma unroll
+    for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
+        B[k] = get_transposed(get_half2(KQ_C[k]));
+    }
+
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
+    __syncthreads();
+    if (!last_iter) {
+        flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
+    }
+#else
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate VKQ tile:
+#pragma unroll
+    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
+        static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
+            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+            tile_A A;
+            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+            mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
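Read together, the two CP_ASYNC_AVAILABLE branches in flash_attn_ext_f16_iter implement a small software pipeline. The intended schedule, as a comment-only sketch (an interpretation of the code above, assuming cp.async is available so that tile_K and tile_V are separate buffers):

    // iteration kb:
    //   cp_async_wait_all(); __syncthreads();      // K tile kb is now complete in tile_K
    //   issue async load of V tile kb   -> tile_V  // overlaps with the KQ matrix multiplications
    //   compute the KQ scores from tile_K
    //   cp_async_wait_all(); __syncthreads();      // V tile kb is now complete in tile_V
    //   issue async load of K tile kb+1 -> tile_K  // overlaps with the VKQ accumulation, skipped on the last iteration
    //   accumulate VKQ from tile_V
    // Without cp.async, tile_V aliases tile_K, every load is synchronous, and the extra
    // __syncthreads() after each multiplication keeps the buffer from being overwritten early.
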
 template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
@@ -13,61 +258,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne00,
         const int ne01,
         const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3,
+        const int stride_Q,
+        const int stride_KV,
+        const int stride_mask,
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
 #ifdef NEW_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

-    typedef mma_A_I16K8<half2>  mma_A;
-    typedef mma_B_J8K8<half2>   mma_B;
-    typedef mma_C_I16J8<float>  mma_C_KQ;
-    typedef mma_C_I16J8<half2>  mma_C_VKQ;
-
-    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
-    constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
+    static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
+    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.

     static_assert(D % nwarps == 0, "bad D");
     static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");

     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
-    extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
-
-    const int stride_Q    = nb01 / sizeof(float2);
-    const int stride_KV   = nb11 / sizeof(half2);
-    const int stride_mask = nb31 / sizeof(half);
+
+    // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
+    extern __shared__ half2 tile_K[];
+#ifdef CP_ASYNC_AVAILABLE
+    half2 * tile_V = tile_K + KQ_stride*D2_padded;
+#else
+    half2 * tile_V = tile_K;
+#endif // CP_ASYNC_AVAILABLE

-    mma_B      Q_B[D/(2*mma_B::K)];
-    mma_C_VKQ  VKQ_C[D/mma_C_VKQ::I];
+    tile_B     Q_B[D/(2*tile_B::J)];
+    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];

     float2 KQ_rowsum = {0.0f, 0.0f};
     float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
-    float2 KQ_max_scale = {0.0f, 0.0f};

-    // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
+    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
@@ -76,6 +300,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int k0_stop  = D/2 - (D/2) % (1*stride_k);
         const int stride_j = WARP_SIZE / stride_k;

+        if (k0_start == k0_stop) {
+            continue;
+        }
+
         if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
             break;
         }
@@ -90,14 +318,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                 const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
-                tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
             }
         } else {
 #pragma unroll
             for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

-                tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
+                tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
             }
         }
     }
@@ -106,198 +334,42 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     __syncthreads();

     {
-        const int j0 = (threadIdx.y / np) * mma_B::J;
+        const int j0 = (threadIdx.y / np) * tile_B::I;

 #pragma unroll
-        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
-            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
         }
     }

     __syncthreads();

+    // Preload K data for first iteration when using cp_async:
+#ifdef CP_ASYNC_AVAILABLE
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
+#endif // CP_ASYNC_AVAILABLE
+
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
-        const int k_VKQ_0 = kb0*KQ_stride;
-        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
-
-        // Load K data into tile with decreasing granularity for D for better memory bandwidth:
-        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
-
-#pragma unroll
-            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
-                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-#pragma unroll
-                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
-                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
-                }
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate tile of KQ:
-#pragma unroll
-        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
-            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
-                mma_A K_A;
-                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
-                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
-            }
-        }
-
-        __syncthreads();
-
-        if (use_logit_softcap) {
-            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
-#pragma unroll
-            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
-#pragma unroll
-                for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
-                }
-            }
-        }
-
-        if (maskh) {
-            static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
-            static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
-#pragma unroll
-            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
-                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
-#pragma unroll
-                for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                    const int i = i0 + mma_C_KQ::get_i(l);
-                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
-
-                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
-                }
-            }
-        }
-
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-        float2 KQ_max_new = KQ_max;
-        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
-#pragma unroll
-        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
-#pragma unroll
-            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
-                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
-                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
-            }
-        }
-
-        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
-#pragma unroll
-        for (int offset = 16; offset > 2; offset >>= 1) {
-            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
-            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
-        }
-
-        {
-            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
-            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
-            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
-                KQ_max_scale.x = 0.0f;
-            }
-            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
-                KQ_max_scale.y = 0.0f;
-            }
-            KQ_max = KQ_max_new;
-        }
-
-        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
-        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
-#pragma unroll
-        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
-#pragma unroll
-            for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
-                const float diff = KQ_C[k].x[l] - KQ_max_l;
-                KQ_C[k].x[l] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_C[k].x[l] = 0.0f;
-                }
-
-                if (l % 2 == 0) {
-                    KQ_rowsum_add.x += KQ_C[k].x[l];
-                } else {
-                    KQ_rowsum_add.y += KQ_C[k].x[l];
-                }
-            }
-        }
-
-        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
-        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
-
-        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
-#pragma unroll
-        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
-#pragma unroll
-            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
-                VKQ_C[i].x[l] *= KQ_max_scale_h2;
-            }
-        }
-
-        // Convert KQ C tiles into B tiles for VKQ calculation:
-        mma_B B[KQ_stride/(np*2*mma_B::K)];
-        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
-#pragma unroll
-        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
-            B[k] = KQ_C[k].to_mma_B();
-        }
-
-        // Load V data into tile with decreasing granularity for D for better memory bandwidth:
-        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
-        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
-            const int i0_stop  =                             D/2 - (D/2) % (1*stride_i);
-            const int stride_k = WARP_SIZE / stride_i;
-
-#pragma unroll
-            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
-                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
-
-#pragma unroll
-                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
-                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
-
-                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
-                }
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate VKQ tile:
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
-            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
-#pragma unroll
-            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
-                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
-
-                mma_A A;
-                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
-                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
-            }
-        }
+    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+        constexpr bool last_iter = false;
+        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+    }
+    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool last_iter = true;
+        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+    }

+    // With cp_async there is no __syncthreads at the end of the iter,
+    // there can be a race condition on shared memory access for combining/writing back results.
+#ifdef CP_ASYNC_AVAILABLE
+    if (nwarps*tile_B::I > KQ_stride) {
         __syncthreads();
     }
+#endif // CP_ASYNC_AVAILABLE

     // Finally, sum up partial KQ rowsums.
     // The partial sums are spread across 8 threads each, does not need full reduce.
@@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // Write VKQ accumulators to shared memory in column-major format.
     // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
    // Also for np > 1 the combination is done via these values in shared memory.
-    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
+    const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
 #pragma unroll
-    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
-        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
+    for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+        const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.

 #pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int k = k0 + mma_B::get_k(l);
+        for (int l = 0; l < tile_B::ne; ++l) {
+            const int k = k0 + tile_B::get_j(l);

-            tile_KV[j_cwd*D2_padded + k] = B.x[l];
+            tile_K[j_cwd*D2_padded + k] = B.x[l];
         }
     }

-    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
-    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
+    const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
+    const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
     const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum

-    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
+    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
         // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
-        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
     }

     __syncthreads();
@@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     static_assert(np == 1 || np == 2 || np == 4, "bad np");
     if (np == 1) {
         // No combination is needed, the meta data can be directly written from registers to VRAM.
-        if (needs_fixup && threadIdx.x < mma_B::J) {
+        if (needs_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[j_cwm] = KQ_cmr;
         }
-        if (is_fixup && threadIdx.x < mma_B::J) {
+        if (is_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[j_cwm] = KQ_cmr;
         }
@@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.

-        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
+        float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;

         float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
             KQ_cm = meta_j[0];
         }

         float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
 #pragma unroll
-        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
             KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }

         const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
         float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
             KQ_crs = KQ_cms*meta_j[1];
         }
 #pragma unroll
-        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
             KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }

         // Write back combined meta data:
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
-            meta_j[0] = KQ_cmn; // Combined max. KQ values.
-            meta_j[1] = KQ_crs; // Combined KQ rowsums.
-            meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
+            *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
         }
-        if (needs_fixup && threadIdx.x < mma_B::J) {
+        if (needs_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && threadIdx.x < mma_B::J) {
+        if (is_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
     }

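The combination above is the usual numerically stable merge of per-warp softmax partials. In the notation of the comments (per-warp max KQ_cm = m_w, per-warp rowsum meta_j[1] = s_w) it computes, per output row,

    m = max_w m_w,        alpha_w = exp(m_w - m),        s = sum_w alpha_w * s_w,

where m is KQ_cmn, alpha_w is KQ_cms and s is KQ_crs after the shuffle reductions. The new write-back stores the per-warp scale and the combined rowsum as a single float2, make_float2(KQ_cms, KQ_crs), in the row padding instead of the previous three separate floats.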
@@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int k0_stop  = D/2 - (D/2) % (1*stride_k);
         const int stride_j = WARP_SIZE / stride_k;

+        if (k0_start == k0_stop) {
+            continue;
+        }
+
         if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
             break;
         }
@@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
         for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
             const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-            const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
+            const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;

             if (!is_fixup && jt*ncols + j_dst >= ne01) {
                 continue;
             }
-            const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
+            const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
 #pragma unroll
             for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
                 for (int ip = 0; ip < np; ++ip) {
-                    const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
-                    const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
+                    const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
+                    const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
                     dstk_val.x += dstk_val_add.x*KQ_crs;
                     dstk_val.y += dstk_val_add.y*KQ_crs;
                 }
@@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         __syncthreads();
     }
 #else
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }

@@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef NEW_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // NEW_MMA_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16(

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

+    const int stride_Q    = nb01 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
+
     const int iter_k = ne11 / KQ_stride;
     const int iter_j = (ne01 + (ncols - 1)) / ncols;

@@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-                 jt, kb0_start, kb0_stop);
+                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-                 jt, kb0_start, kb0_stop);
+                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
         }

         kbc += iter_k;
@@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-         ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-         jt, kb0_start, kb0_stop);
+         ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
 }

 template <int D, int cols_per_block>
 void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    typedef mma_A_I16K8<half2> mma_A;
-    typedef mma_B_J8K8<half2>  mma_B;
+    typedef tile<16, 8, half2> tile_A;
+    typedef tile< 8, 8, half2> tile_B;

-    static_assert(D % mma_B::K == 0, "bad D");
-    static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
+    static_assert(D % tile_B::J == 0, "bad D");
+    static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");

     const ggml_tensor * KQV = dst;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

     constexpr int KQ_stride = D <= 128 ? 64 : 32;
     constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
-        cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
-    constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
+        cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
+
+    const int nrows_KQ      = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
+    const int nrows_combine = nwarps*tile_B::J;
+    const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);

     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
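
The effect of the new shared-memory sizing, for one assumed instantiation (D = 128, hence KQ_stride = 64, with nwarps = 8 and tile_B::J = 8; numbers not from the diff):

    nrows_combine = nwarps*tile_B::J = 64
    without cp.async: nrows_KQ =   KQ_stride =  64  ->  nbytes_shared =  64*(128 + 8)*sizeof(half) = 17408 bytes
    with    cp.async: nrows_KQ = 2*KQ_stride = 128  ->  nbytes_shared = 128*(128 + 8)*sizeof(half) = 34816 bytes

So the async path roughly doubles the shared-memory footprint in order to keep separate K and V buffers resident at the same time.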
@@ -4,11 +4,12 @@
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
 //
 // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
-// A is a row-major matrix with shape I x K.
-// B is a column-major matrix with shape K x J.
-// C is a column-major matrix with shape I x J.
-// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
-// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
 // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
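
A minimal sketch of how these tiles compose, using only names that appear in this diff (the shapes mirror the typedefs in the FlashAttention kernel above; assumes NEW_MMA_AVAILABLE and 16-byte aligned shared-memory pointers, and is not part of the diff):

    using namespace ggml_cuda_mma;

    static __device__ void tile_mma_sketch(const half2 * smem_A, const half2 * smem_B, const int stride) {
        tile<16, 8, half2> A; // 16 rows x 8 half2 (16 logical half values per row), row-major
        tile< 8, 8, half2> B; // the operand that is consumed column-major
        tile<16, 8, float> C; // FP32 accumulator with C = A @ B

        load_generic (A, smem_A, stride); // element-wise load, works for any tile shape
        load_ldmatrix(B, smem_B, stride); // 8x8 tiles can use the ldmatrix fast path
        mma(C, A, B);                     // one warp-wide matrix multiply-accumulate
    }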
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {

 #ifdef NEW_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-        : "+r"(ret) : "r"(x));
+        : "=r"(ret) : "r"(x));
 #else
     NO_DEVICE_CODE;
 #endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
 #endif // CUDART_VERSION >= 11080

+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+    half2 ret;
+    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+    return ret;
+}

-template <typename T>
+namespace ggml_cuda_mma {
-struct mma_A_I16K4 {
-    static_assert(sizeof(T) == 4, "bad type size");

-    static constexpr int I  = 16;
+    template <int I_, int J_, typename T>
-    static constexpr int K  = 4;
+    struct tile {
-    static constexpr int ne = 2;
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        T x[ne] = {0};

-    T x[ne];
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && (J == 4 || J == 8)) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }

-    static __device__ __forceinline__ int get_i(const int l) {
+        static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+            if constexpr (I == 8 && J == 4) {
-        GGML_CUDA_ASSUME(ret >= 0);
+                return threadIdx.x % 4;
-        GGML_CUDA_ASSUME(ret <  I);
+            } else if constexpr (I == 8 && J == 8) {
-        return ret;
+                return 4 * l + threadIdx.x % 4;
-    }
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x % 4) + l % 2;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };

-    static __device__ __forceinline__ int get_k(const int /* l */) {
+    template <int I_, int J_>
-        const int ret = threadIdx.x % K;
+    struct tile<I_, J_, half2> {
-        GGML_CUDA_ASSUME(ret >= 0);
+        static constexpr int I  = I_;
-        GGML_CUDA_ASSUME(ret <  K);
+        static constexpr int J  = J_;
-        return ret;
+        static constexpr int ne = I * J / WARP_SIZE;
-    }
+        half2 x[ne] = {{0.0f, 0.0f}};

-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }

+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };

+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
 #pragma unroll
-        for (int l = 0; l < ne; ++l) {
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
+        return ret;
+    }

+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        tile<8, 8, half2> ret;
+        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);

+        return ret;
+    }

+    template <int I, int J, typename T>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l = 0; l < t.ne; ++l) {
+            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
     }

-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *) x;
+        int * xi = (int *) t.x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(xi[0]), "+r"(xi[1])
+                     : "=r"(xi[0]), "=r"(xi[1])
+                     : "l"(xs));
+#else
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
+    }

+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+                     : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
         load_generic(xs0, stride);
 #endif // NEW_MMA_AVAILABLE
     }
-};

     template <typename T>
-struct mma_A_I16K8 {
+    static __device__ __forceinline__ void load_ldmatrix(
-    static_assert(sizeof(T) == 4, "bad type size");
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {

-    static constexpr int I  = 16;
-    static constexpr int K  = 8;
-    static constexpr int ne = 4;

-    T x[ne];

-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }

-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
-        }
-    }

-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int * ) x;
+        int * xi = (int * ) t.x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+                     : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
+    }

+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix_trans(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+                     : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
+                     : "l"(xs));
+#else
+        GGML_UNUSED(t);
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }

-    __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
+    static __device__ __forceinline__ void mma(
-#ifdef NEW_MMA_AVAILABLE
+            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-        int * xi = (int * ) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
-            : "l"(xs));
-#else
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
-    }

-    __device__ __forceinline__ void transpose() {
-        int * xi = (int *) x;
-        xi[0] = ggml_cuda_movmatrix(xi[0]);

-        const int tmp = ggml_cuda_movmatrix(xi[1]);
-        xi[1] = ggml_cuda_movmatrix(xi[2]);
-        xi[2] = tmp;

-        xi[3] = ggml_cuda_movmatrix(xi[3]);
-    }
-};

-template <typename T>
-struct mma_B_J8K4 {
-    static_assert(sizeof(T) == 4, "bad type size");

-    static constexpr int J  = 8;
-    static constexpr int K  = 4;
-    static constexpr int ne = 1;

-    T x[ne];

-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }

-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-    }

-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
-#ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(xi[0]) : "l"(xs));
-#else
-        load_generic(xs0, stride);
-#endif // NEW_MMA_AVAILABLE
-    }
-};

-template <typename T>
-struct mma_B_J8K8 {
-    static_assert(sizeof(T) == 4, "bad type size");

-    static constexpr int J  = 8;
-    static constexpr int K  = 8;
-    static constexpr int ne = 2;

-    T x[ne];

-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = l * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }

-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-    }

-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
-#ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(xi[0]), "+r"(xi[1])
-            : "l"(xs));
-#else
-        load_generic(xs0, stride);
-#endif // NEW_MMA_AVAILABLE
-    }
-};

-template <typename T>
-struct mma_C_I16J8 {};

-template <>
-struct mma_C_I16J8<int> {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;

-    int x[ne] = {0};

-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }

-    __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
 #ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
+            : "+r"(D.x[0]), "+r"(D.x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
+            : "+r"(D.x[2]), "+r"(D.x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
+        GGML_UNUSED(D);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }

-    __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
 #ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
 #else
         // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
+            : "+r"(D.x[0]), "+r"(D.x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
+            : "+r"(D.x[2]), "+r"(D.x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "r"(A.x[1]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
+            : "+r"(D.x[0]), "+r"(D.x[1])
-            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+            : "r"(A.x[2]), "r"(B.x[1]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
+            : "+r"(D.x[2]), "+r"(D.x[3])
-            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+            : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
+        GGML_UNUSED(D);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
-};

-template <>
+    static __device__ __forceinline__ void mma(
-struct mma_C_I16J8<half2> {
+            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-    static constexpr int I  = 16;
-    static constexpr int J  = 4;
-    static constexpr int ne = 2;

-    half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};

-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = l * (I/2) + threadIdx.x / J;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x % J;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }

-    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
 #ifdef NEW_MMA_AVAILABLE
-        int * Axi = (int *) mma_A.x;
+        const int * Axi = (const int *) A.x;
-        int * Bxi = (int *) mma_B.x;
+        const int * Bxi = (const int *) B.x;
-        int *  xi = (int *) x;
+        int  * Dxi = (int *) D.x;
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
+        GGML_UNUSED(D);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }

-    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+    static __device__ __forceinline__ void mma(
-        mma_B_J8K8<half2> mma_B;
+            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {

-        int * xi  = (int *) x;
-        int * Bxi = (int *) mma_B.x;
-        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
-        Bxi[1] = ggml_cuda_movmatrix(xi[1]);

-        return mma_B;
-    }
-};

-template <>
-struct mma_C_I16J8<float> {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;

-    float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};

-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }

-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }

-    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
 #ifdef NEW_MMA_AVAILABLE
-        int * Axi = (int *) mma_A.x;
+        const int * Axi = (const int *) A.x;
-        int * Bxi = (int *) mma_B.x;
+        const int * Bxi = (const int *) B.x;
-        int *  xi = (int *) x;
+        int  * Dxi = (int *) D.x;
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
+        GGML_UNUSED(D);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }

-    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+}
-        mma_B_J8K8<half2> mma_B;
-        mma_B.x[0] = make_half2(x[0], x[1]);
-        mma_B.x[1] = make_half2(x[2], x[3]);

-        int * Bxi = (int *) mma_B.x;
-        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
-        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);

-        return mma_B;
-    }

-    __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_i(l)];
-        }
-    }
-};
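As a reading aid only — this fragment is not code from the commit — the sketch below shows how a call site migrates from the old per-shape structs to the new ggml_cuda_mma::tile template and free functions introduced above. The function name, the include path, and the assumption that As/Bs point into correctly laid-out shared memory are all illustrative; the surrounding kernel is omitted.

// Device-side sketch (assumptions: a header providing ggml_cuda_mma, e.g. "mma.cuh",
// and NEW_MMA_AVAILABLE hardware so load_ldmatrix uses the ldmatrix path).
#include "mma.cuh"

using namespace ggml_cuda_mma;

__device__ void tile_call_style_example(const int * __restrict__ As,  // shared memory, 16x8 A fragment
                                        const int * __restrict__ Bs,  // shared memory,  8x8 B fragment
                                        const int stride, int * out) {
    tile<16, 8, int> A;  // was: mma_A_I16K8<int> A;
    tile< 8, 8, int> B;  // was: mma_B_J8K8<int>  B;
    tile<16, 8, int> C;  // was: mma_C_I16J8<int> C;

    load_ldmatrix(A, As, stride);  // was: A.load_ldmatrix(As, stride);
    load_generic (B, Bs, stride);  // was: B.load_generic(Bs, stride);
    mma(C, A, B);                  // was: C.mma(A, B);

#pragma unroll
    for (int l = 0; l < C.ne; ++l) {
        out[C.get_i(l)*8 + C.get_j(l)] = C.x[l];  // scatter this thread's accumulators into a 16x8 patch
    }
}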
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>

+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

-    typedef mma_A_I16K8<int> mma_A;
+    typedef tile<16, 8, int> tile_A;
-    typedef mma_B_J8K8<int>  mma_B;
+    typedef tile< 8, 8, int> tile_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;

-    mma_A A[ntx][WARP_SIZE/QI8_0];
+    tile_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];

     const int i0 = (threadIdx.y/ntx)*rows_per_warp;

@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;

-            A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
+            tile_B B;
-            float dB[mma_C::ne/2];
+            float dB[tile_C::ne/2];

-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + tile_C::get_j(l);

                 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                     dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
+                tile_C C;
-                C.mma(A[n][k01/QI8_0], B);
+                mma(C, A[n][k01/QI8_0], B);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                 }
             }
         }
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

-    typedef mma_A_I16K8<int> mma_A;
+    typedef tile<16, 8, int> tile_A;
-    typedef mma_B_J8K8<int>  mma_B;
+    typedef tile< 8, 8, int> tile_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;

-    mma_A A[ntx][WARP_SIZE/QI8_1];
+    tile_A A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];

     const int i0 = (threadIdx.y/ntx)*rows_per_warp;

@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;

-            A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B;
+            tile_B B;
-            float2 dsB[mma_C::ne/2];
+            float2 dsB[tile_C::ne/2];

-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + tile_C::get_j(l);

                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
+                tile_C C;
-                C.mma(A[n][k01/QI8_1], B);
+                mma(C, A[n][k01/QI8_1], B);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                 }
             }
         }
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-    typedef mma_A_I16K4<int> mma_A;
+    typedef tile<16, 4, int> tile_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
+    typedef tile<16, 8, int> tile_A_8;
-    typedef mma_B_J8K4<int>  mma_B;
+    typedef tile< 8, 4, int> tile_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;

-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-    mma_A A[ntx][8];
+    tile_A A[ntx][8];
-    float dA[ntx][mma_C::ne/2][8];
+    float dA[ntx][tile_C::ne/2][8];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;

-            ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
+            tile_B B[2];
-            float dB[mma_C::ne/2];
+            float dB[tile_C::ne/2];

             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + tile_C::get_j(l);

                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
+                tile_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
+                mma(C[0], A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
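The hunk above (and the q2_K hunk that follows) loads pairs of 16x4 A fragments through a cast to a 16x8 tile, as in load_ldmatrix(((tile_A_8 *) A[n])[k01/8], ...). The host-side sketch below is illustrative only; tile_stub is a stand-in for the real tile template, used to spell out the storage assumption the cast relies on: two tile<16, 4, int> fragments occupy exactly the registers of one tile<16, 8, int>.

// Host-side sketch (assumption: only the ne/x[ne] storage of the real tile matters here).
#include <cstdio>

constexpr int WARP_SIZE = 32;

template <int I, int J> struct tile_stub { // mirrors tile<I, J, int>'s per-thread storage
    static constexpr int ne = I * J / WARP_SIZE;
    int x[ne];
};

static_assert(sizeof(tile_stub<16, 4>[2]) == sizeof(tile_stub<16, 8>),
              "two 16x4 fragments alias one 16x8 fragment");

int main() {
    printf("ne(16x4) = %d, ne(16x8) = %d\n", tile_stub<16, 4>::ne, tile_stub<16, 8>::ne);
    return 0;
}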
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-    typedef mma_A_I16K4<int> mma_A;
+    typedef tile<16, 4, int> tile_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
+    typedef tile<16, 8, int> tile_A_8;
-    typedef mma_B_J8K4<int>  mma_B;
+    typedef tile< 8, 4, int> tile_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-    mma_A A[ntx][8];
+    tile_A A[ntx][8];
-    float dA[ntx][mma_C::ne/2][8];
+    float dA[ntx][tile_C::ne/2][8];
-    float mA[ntx][mma_C::ne/2][8];
+    float mA[ntx][tile_C::ne/2][8];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;

-            ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-        float2 dB[mma_C::ne/2];
+        float2 dB[tile_C::ne/2];

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+            const int j = j0 + tile_C::get_j(l);

             dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }

 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B[2];
+            tile_B B[2];

             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

-            mma_C Cm[2];
+            tile_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
-                mma_A A1;
+                tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma(A1, B[0]);
+                mma(Cm[0], A1, B[0]);
-                Cm[1].mma(A1, B[1]);
+                mma(Cm[1], A1, B[1]);
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C Cd[2];
+                tile_C Cd[2];

-                Cd[0].mma(A[n][k01/4 + 0], B[0]);
+                mma(Cd[0], A[n][k01/4 + 0], B[0]);
-                Cd[1].mma(A[n][k01/4 + 1], B[1]);
+                mma(Cd[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                     if (k01 >= WARP_SIZE * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
         }

 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-            float2 sB[mma_C::ne/2];
+            float2 sB[tile_C::ne/2];

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + tile_C::get_j(l);

                 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                 }
             }
         }
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-    typedef mma_A_I16K4<int> mma_A;
+    typedef tile<16, 4, int> tile_A;
-    typedef mma_B_J8K4<int>  mma_B;
+    typedef tile< 8, 4, int> tile_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;

-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-    mma_A A[ntx][8];
+    tile_A A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
+    int   scA[ntx][tile_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
+    float  dA[ntx][tile_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;

-            A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }

 #pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

                 const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc        = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
+        for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+        float tmp[ntx][tile_C::ne] = {{0.0f}};

 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
+            tile_B B[2];
-            float dB[mma_C::ne/2];
+            float dB[tile_C::ne/2];

             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
+            for (int l = 0; l < tile_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + tile_C::get_j(l);

                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
+                tile_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
+                mma(C[0], A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
+            for (int l = 0; l < tile_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
 #ifdef NEW_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #endif // NEW_MMA_AVAILABLE

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
+            for (int l = 0; l < tile_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);

                 if (j > j_max) {
                     continue;
                 }

-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);

                 if (need_check && i > i_max) {
                     continue;
                 }

-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
 }
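Throughout the MMQ kernels in this diff the per-thread accumulators are flattened as sum[(j0/tile_C::J + n)*tile_C::ne + l]. The host-side sketch below is illustrative only; mmq_x = 64 and ntx = 2 are assumed example values, not taken from the diff. It checks that this flattened index gives every (j0, n, l) combination a thread touches its own slot.

// Host-side sketch (assumptions noted above; tile_C is modelled only by its J and ne).
#include <cassert>
#include <cstdio>
#include <set>

int main() {
    constexpr int tile_C_J = 8, tile_C_ne = 4; // shape of tile<16, 8, int>: J == 8, ne == 4
    constexpr int mmq_x = 64, ntx = 2;         // example launch parameters (assumed)

    std::set<int> used;
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C_J) {
        for (int n = 0; n < ntx; ++n) {
            for (int l = 0; l < tile_C_ne; ++l) {
                const int idx = (j0/tile_C_J + n)*tile_C_ne + l;
                assert(used.insert(idx).second); // no two accumulators share a slot
            }
        }
    }
    printf("distinct sum[] slots per thread: %zu (expected %d)\n",
           used.size(), (mmq_x/tile_C_J)*tile_C_ne);
    return 0;
}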