CUDA: faster q8_0 -> f16 dequantization (llama/4895)

This commit is contained in:
Johannes Gäßler 2024-01-12 20:38:54 +01:00 committed by Georgi Gerganov
parent db078a9ba8
commit 12490f4398
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -523,6 +523,8 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16
#define CUDA_ACC_BLOCK_SIZE 256 #define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256 #define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_Q8_0_NE_ALIGN 2048
// dmmv = dequantize_mul_mat_vec // dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X #ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32 #define GGML_CUDA_DMMV_X 32
@ -2327,6 +2329,45 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
y[i] = x[i]; y[i] = x[i];
} }
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
#if __CUDA_ARCH__ >= CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
const int * x0 = ((int *) vx) + blockIdx.x * nint;
half2 * y2 = (half2 *) (y + i0);
__shared__ int vals[nint];
#pragma unroll
for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
break;
}
const int ix = ix0 + threadIdx.x;
vals[ix] = x0[ix];
}
#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
return;
}
const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
const half d = *b0;
const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
}
#else
(void) vx; (void) y; (void) k;
bad_arch();
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@ -6181,6 +6222,17 @@ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restri
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
} }
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
} else {
const bool need_check = true;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
}
}
template<typename dst_t> template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K; const int nb = k / QK_K;
@ -6246,6 +6298,7 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
} }
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
int id;
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@ -6256,6 +6309,10 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].cc >= CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda; return dequantize_row_q2_K_cuda;