mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-21 13:37:47 +00:00
CUDA: faster q8_0 -> f16 dequantization (llama/4895)
This commit is contained in:
parent
db078a9ba8
commit
12490f4398
57
ggml-cuda.cu
57
ggml-cuda.cu
@ -523,6 +523,8 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16
|
|||||||
#define CUDA_ACC_BLOCK_SIZE 256
|
#define CUDA_ACC_BLOCK_SIZE 256
|
||||||
#define CUDA_IM2COL_BLOCK_SIZE 256
|
#define CUDA_IM2COL_BLOCK_SIZE 256
|
||||||
|
|
||||||
|
#define CUDA_Q8_0_NE_ALIGN 2048
|
||||||
|
|
||||||
// dmmv = dequantize_mul_mat_vec
|
// dmmv = dequantize_mul_mat_vec
|
||||||
#ifndef GGML_CUDA_DMMV_X
|
#ifndef GGML_CUDA_DMMV_X
|
||||||
#define GGML_CUDA_DMMV_X 32
|
#define GGML_CUDA_DMMV_X 32
|
||||||
@ -2327,6 +2329,45 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
|
|||||||
y[i] = x[i];
|
y[i] = x[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool need_check>
|
||||||
|
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
|
||||||
|
#if __CUDA_ARCH__ >= CC_PASCAL
|
||||||
|
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
|
||||||
|
|
||||||
|
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
|
||||||
|
const int * x0 = ((int *) vx) + blockIdx.x * nint;
|
||||||
|
half2 * y2 = (half2 *) (y + i0);
|
||||||
|
|
||||||
|
__shared__ int vals[nint];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
|
||||||
|
if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ix = ix0 + threadIdx.x;
|
||||||
|
vals[ix] = x0[ix];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
|
||||||
|
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
|
||||||
|
const half d = *b0;
|
||||||
|
const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
|
||||||
|
|
||||||
|
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
(void) vx; (void) y; (void) k;
|
||||||
|
bad_arch();
|
||||||
|
#endif // __CUDA_ARCH__ >= CC_PASCAL
|
||||||
|
}
|
||||||
|
|
||||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||||
|
|
||||||
@ -6181,6 +6222,17 @@ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restri
|
|||||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
|
||||||
|
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
|
||||||
|
const bool need_check = false;
|
||||||
|
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
|
||||||
|
} else {
|
||||||
|
const bool need_check = true;
|
||||||
|
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
@ -6246,6 +6298,7 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
|
|||||||
}
|
}
|
||||||
|
|
||||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||||
|
int id;
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
||||||
@ -6256,6 +6309,10 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
|
if (g_device_caps[id].cc >= CC_PASCAL) {
|
||||||
|
return dequantize_block_q8_0_f16_cuda;
|
||||||
|
}
|
||||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
return dequantize_row_q2_K_cuda;
|
return dequantize_row_q2_K_cuda;
|
||||||
|
Loading…
Reference in New Issue
Block a user