CUDA: fix non-cont. inputs for batched mat mul (llama/13155)

This commit is contained in:
Johannes Gäßler 2025-04-29 16:00:27 +02:00 committed by Georgi Gerganov
parent 4872355f6e
commit 1543a3600c
3 changed files with 94 additions and 41 deletions

View File

@ -1,6 +1,8 @@
#include "convert.cuh"
#include "dequantize.cuh"
#include <cstdint>
#define CUDA_Q8_0_NE_ALIGN 2048
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@ -570,30 +572,46 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
if (i00 >= ne00) {
return;
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
y[i] = float(x[i]);
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = float(x[ix]);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
}
template <typename src_t, typename dst_t>
static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
}
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
return convert_unary_cont_cuda<float>;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
return convert_unary_cont_cuda<half>;
default:
return nullptr;
}
@ -643,9 +661,9 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
return convert_unary_cont_cuda<nv_bfloat16>;
default:
return nullptr;
}
@ -692,7 +710,18 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
default:
return nullptr;
}
}
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
default:

View File

@ -3,7 +3,7 @@
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
template<typename T>
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
@ -14,3 +14,13 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
// TODO more general support for non-contiguous inputs
template<typename T>
using to_t_nc_cuda_t = void (*)(const void * x, T * y,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);

View File

@ -1720,15 +1720,15 @@ static __global__ void k_compute_batched_ptrs(
size_t nb12, size_t nb13,
size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3) {
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
if (i13 >= ne13 || i12 >= ne12) {
return;
}
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;
const int64_t i03 = i13 / r3;
const int64_t i02 = i12 / r2;
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
@ -1742,6 +1742,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
// As long as dst is contiguous this does not matter though.
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_TENSOR_BINARY_OP_LOCALS
const int64_t ne_dst = ggml_nelements(dst);
@ -1750,21 +1754,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
void * src0_ddq = src0->data;
half * src0_f16 = (half *) src0_ddq;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
const half * src0_f16 = (const half *) src0->data;
float * dst_ddf = (float *) dst->data;
const half * src1_f16 = (const half *) src1->data;
const size_t ts_src1 = ggml_type_size(src1->type);
GGML_ASSERT(nb10 == ts_src1);
int64_t s11 = nb11 / ts_src1;
int64_t s12 = nb12 / ts_src1;
int64_t s13 = nb13 / ts_src1;
ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
// convert src1 to fp16
ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
GGML_ASSERT(to_fp16_cuda != nullptr);
to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
src1_f16 = src1_f16_alloc.get();
s11 = ne10;
s12 = ne11*s11;
s13 = ne12*s12;
}
half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
char * dst_t;
@ -1824,13 +1838,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
int i02 = i12 / r2;
CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
}
@ -1841,15 +1855,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
(const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
src1_f16, CUDA_R_16F, s11, s12, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;
const int64_t ne23 = ne12*ne13;
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
@ -1861,8 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
ne12, ne13,
ne23,
nb02, nb03,
src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());
@ -1871,8 +1885,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@ -1936,7 +1950,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} else if (!split && use_mul_mat_vec_q) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
!ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {