CUDA: refactor mmq, dmmv, mmvq (llama/7716)

* CUDA: refactor mmq, dmmv, mmvq

* fix out-of-bounds write

* struct for qk, qr, qi

* fix cmake build

* mmq_type_traits
This commit is contained in:
Johannes Gäßler 2024-06-05 16:53:00 +02:00 committed by Georgi Gerganov
parent abab4500fa
commit e08c62149b
110 changed files with 1778 additions and 1767 deletions

View File

@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
#define QI1_S (QK_K / (4*QR1_S)) #define QI1_S (QK_K / (4*QR1_S))
#define QR1_S 8 #define QR1_S 8
#define QI1_M (QK_K / (4*QR1_M))
#define QR1_M 8
#define QI4_NL (QK4_NL / (4*QR4_NL)) #define QI4_NL (QK4_NL / (4*QR4_NL))
#define QR4_NL 2 #define QR4_NL 2
#define QI4_XS (QK_K / (4*QR4_XS)) #define QI4_XS (QK_K / (4*QR4_XS))
#define QR4_XS 8 #define QR4_XS 8
#define QI3_S (QK_K / (4*QR3_S))
#define QR3_S 8
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
#define QK4_0 32 #define QK4_0 32

View File

@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
// cuda split buffer // cuda split buffer
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) { static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX; int64_t row_rounding = 0;
int64_t max_compute_capability = INT_MIN;
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > ggml_cuda_info().devices[id].cc) { continue;
min_compute_capability = ggml_cuda_info().devices[id].cc;
}
if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
max_compute_capability = ggml_cuda_info().devices[id].cc;
}
} }
}
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) const int cc = ggml_cuda_info().devices[id].cc;
switch(type) { row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
case GGML_TYPE_Q3_K:
return min_compute_capability < CC_RDNA2 ? 128 : 64;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
} }
#else return row_rounding;
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
default:
GGML_ASSERT(false);
}
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
} }
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) { static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
const int64_t nrows = ggml_nrows(tensor); const int64_t nrows = ggml_nrows(tensor);
const int64_t rounding = get_row_rounding(tensor->type, tensor_split); const int64_t rounding = get_row_rounding(tensor_split);
*row_low = id == 0 ? 0 : nrows*tensor_split[id]; *row_low = id == 0 ? 0 : nrows*tensor_split[id];
*row_low -= *row_low % rounding; *row_low -= *row_low % rounding;
@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
// for multi GPU, get the row boundaries from tensor split // for multi GPU, get the row boundaries from tensor split
// and round to mul_mat_q tile sizes // and round to mul_mat_q tile sizes
if (split) { if (split) {
const int64_t rounding = get_row_rounding(src0->type, tensor_split); const int64_t rounding = get_row_rounding(tensor_split);
if (id != 0) { if (id != 0) {
dev[id].row_low = ne01*tensor_split[id]; dev[id].row_low = ne01*tensor_split[id];

View File

@ -160,7 +160,7 @@
#endif #endif
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels #define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available #define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope(
return powf(base, exph); return powf(base, exph);
} }
template <ggml_type type>
struct ggml_cuda_type_traits;
template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
static constexpr int qk = 1;
static constexpr int qr = 1;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
static constexpr int qr = QR4_0;
static constexpr int qi = QI4_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
static constexpr int qk = QK4_1;
static constexpr int qr = QR4_1;
static constexpr int qi = QI4_1;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
static constexpr int qk = QK5_0;
static constexpr int qr = QR5_0;
static constexpr int qi = QI5_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
static constexpr int qk = QK5_1;
static constexpr int qr = QR5_1;
static constexpr int qi = QI5_1;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qk = QK8_0;
static constexpr int qr = QR8_0;
static constexpr int qi = QI8_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_K;
static constexpr int qi = QI2_K;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_K;
static constexpr int qi = QI3_K;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_K;
static constexpr int qi = QI4_K;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR5_K;
static constexpr int qi = QI5_K;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR6_K;
static constexpr int qi = QI6_K;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XXS;
static constexpr int qi = QI2_XXS;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XS;
static constexpr int qi = QI2_XS;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_S;
static constexpr int qi = QI2_S;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_XXS;
static constexpr int qi = QI3_XXS;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_S;
static constexpr int qi = QI1_S;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_M;
static constexpr int qi = QI1_M;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
static constexpr int qi = QI4_NL;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_S;
static constexpr int qi = QI3_S;
};
static int get_mmq_x_max_host(const int cc) {
#ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
#endif // CUDA_USE_TENSOR_CORES
}
// Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc, const int mmq_x) {
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
}
////////////////////// //////////////////////
struct ggml_cuda_device_info { struct ggml_cuda_device_info {

View File

@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
v.y = x[ib + iqs + 1]; v.y = x[ib + iqs + 1];
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
type == GGML_TYPE_F16 ? convert_f16 :
nullptr;
}
template <ggml_type type>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
// qr = number of quantized weights per data value in x block constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) { if (row >= nrows) {
@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0> dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1> dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0> dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1> dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0> dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16> dequantize_mul_mat_vec<GGML_TYPE_F16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,47 @@
#include "mmvq.cuh" #include "mmvq.cuh"
#include "vecdotq.cuh" #include "vecdotq.cuh"
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
nullptr;
}
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
1;
}
template <ggml_type type, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below // tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1; constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1; constexpr int rows_per_cuda_block = 1;
@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
// partial sum for each thread // partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy; const block_q8_1 * y = (const block_q8_1 *) vy;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
for (int j = 0; j < ncols_y; ++j) { for (int j = 0; j < ncols_y; ++j) {
#pragma unroll #pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) { for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda( tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
} }
} }
} }
@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q(
} }
} }
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot> template <ggml_type type>
static void mul_mat_vec_q_cuda( static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id = ggml_cuda_get_device(); int id = ggml_cuda_get_device();
@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
switch (ncols_y) { switch (ncols_y) {
case 1: case 1:
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 2: case 2:
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 3: case 3:
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 4: case 4:
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 5: case 5:
mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 6: case 6:
mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 7: case 7:
mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
case 8: case 8:
mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q4_1_q8_1_cuda( static void mul_mat_vec_q4_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q5_0_q8_1_cuda( static void mul_mat_vec_q5_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q5_1_q8_1_cuda( static void mul_mat_vec_q5_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q8_0_q8_1_cuda( static void mul_mat_vec_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q2_K_q8_1_cuda( static void mul_mat_vec_q2_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q3_K_q8_1_cuda( static void mul_mat_vec_q3_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q4_K_q8_1_cuda( static void mul_mat_vec_q4_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q5_K_q8_1_cuda( static void mul_mat_vec_q5_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_q6_K_q8_1_cuda( static void mul_mat_vec_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq2_xxs_q8_1_cuda( static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq2_xs_q8_1_cuda( static void mul_mat_vec_iq2_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq2_s_q8_1_cuda( static void mul_mat_vec_iq2_s_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq3_xxs_q8_1_cuda( static void mul_mat_vec_iq3_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq1_s_q8_1_cuda( static void mul_mat_vec_iq1_s_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq1_m_q8_1_cuda( static void mul_mat_vec_iq1_m_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq4_nl_q8_1_cuda( static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq4_xs_q8_1_cuda( static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
static void mul_mat_vec_iq3_s_q8_1_cuda( static void mul_mat_vec_iq3_s_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1> mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
void ggml_cuda_op_mul_mat_vec_q( void ggml_cuda_op_mul_mat_vec_q(

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh" #include "../fattn-vec-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh" #include "../fattn-vec-f32.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh" #include "../fattn-wmma-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh" #include "../fattn-wmma-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh" #include "../fattn-wmma-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh" #include "../fattn-wmma-f16.cuh"

View File

@ -1,4 +1,4 @@
// This file has been autogenerated by generate-variants.py, do not edit manually. // This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh" #include "../fattn-wmma-f16.cuh"

View File

@ -20,6 +20,18 @@ SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE({type});
"""
def get_short_name(long_quant_name): def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower() return long_quant_name.replace("GGML_TYPE_", "").lower()
@ -57,3 +69,7 @@ for kq_acc_t in ["half", "float"]:
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
continue continue
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type))

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_Q2_K);

Some files were not shown because too many files have changed in this diff Show More