From 1543a3600cdabdd7cbcc0aaab8b29e86b3f44f8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 29 Apr 2025 16:00:27 +0200
Subject: [PATCH] CUDA: fix non-cont. inputs for batched mat mul (llama/13155)

---
 ggml/src/ggml-cuda/convert.cu   | 53 +++++++++++++++++++------
 ggml/src/ggml-cuda/convert.cuh  | 12 +++++-
 ggml/src/ggml-cuda/ggml-cuda.cu | 70 ++++++++++++++++++++-------------
 3 files changed, 94 insertions(+), 41 deletions(-)

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index a224ec0e..c6dec427 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -1,6 +1,8 @@
 #include "convert.cuh"
 #include "dequantize.cuh"
 
+#include <cstdint>
+
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -570,30 +572,46 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
 }
 
 template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
     const src_t * x = (const src_t *) vx;
 
-    y[i] = float(x[i]);
+    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
+    y[iy] = float(x[ix]);
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void convert_unary_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
 }
 
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
-            return convert_unary_cuda<float>;
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_F16:
-            return convert_unary_cuda<half>;
+            return convert_unary_cont_cuda<half>;
         default:
             return nullptr;
     }
@@ -643,9 +661,9 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
-            return convert_unary_cuda<float>;
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
-            return convert_unary_cuda<nv_bfloat16>;
+            return convert_unary_cont_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
@@ -692,7 +710,18 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
-            return convert_unary_cuda<half>;
+            return convert_unary_cont_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
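Note: the heart of the change above is the generalized convert_unary kernel, which keeps the destination contiguous while addressing the source through per-dimension strides s01/s02/s03, given in elements. A minimal CPU sketch of the same index mapping, for reference only (convert_unary_ref is a hypothetical name, not part of the patch):

    // Gather from a strided source into a contiguous ne03 x ne02 x ne01 x ne00
    // destination. The CUDA kernel does the same with i00 on the x grid
    // dimension, i01 on blockIdx.y and (i02, i03) packed into blockIdx.z,
    // plus a src_t -> dst_t conversion.
    #include <cstdint>

    static void convert_unary_ref(const float * x, float * y,
            int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
            int64_t s01, int64_t s02, int64_t s03) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
        for (int64_t i01 = 0; i01 < ne01; i01++) {
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;           // strided source
            const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;    // contiguous destination
            y[iy] = x[ix];
        }}}}
    }

The contiguous case is then just the special case ne01 = ne02 = ne03 = 1 with a single row of k elements, which is exactly what the new convert_unary_cont_cuda wrapper forwards.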
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index 411a13cf..b65b98e0 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -3,7 +3,7 @@
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half>  to_fp16_cuda_t;
@@ -14,3 +14,13 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
 
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
+
+// TODO more general support for non-contiguous inputs
+
+template<typename T>
+using to_t_nc_cuda_t = void (*)(const void * x, T * y,
+    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
+
+typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
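The new to_t_nc_cuda_t entry point takes the logical sizes ne00..ne03 plus strides s01/s02/s03 in elements, so byte strides from ggml_tensor::nb must be divided by the element size first, and the innermost dimension has to stay contiguous. A sketch of the expected call pattern, mirroring its use in ggml_cuda_mul_mat_batched_cublas below (src1, dst_f16 and stream are assumed to be in scope; illustrative only):

    const to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(src1->type);
    GGML_ASSERT(to_fp16 != nullptr);      // only F32 and BF16 sources are handled

    const size_t ts = ggml_type_size(src1->type);
    GGML_ASSERT(src1->nb[0] == ts);       // innermost dim must be contiguous

    to_fp16(src1->data, dst_f16,          // dst_f16: preallocated half buffer
            src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
            src1->nb[1]/ts, src1->nb[2]/ts, src1->nb[3]/ts, stream);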
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 19b9ce72..fba8cb65 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1720,15 +1720,15 @@ static __global__ void k_compute_batched_ptrs(
         size_t  nb12, size_t  nb13,
         size_t  nbd2, size_t  nbd3,
         int64_t r2,   int64_t r3) {
-    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
@@ -1742,6 +1742,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
+    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
+    // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
@@ -1750,21 +1754,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    void  * src0_ddq = src0->data;
-    half  * src0_f16 = (half  *) src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
+    const half * src0_f16 = (const half *) src0->data;
+    float * dst_ddf = (float *) dst->data;
+
+    const half * src1_f16 = (const half *) src1->data;
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == ts_src1);
+    int64_t s11 = nb11 / ts_src1;
+    int64_t s12 = nb12 / ts_src1;
+    int64_t s13 = nb13 / ts_src1;
+    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
     // convert src1 to fp16
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
         GGML_ASSERT(to_fp16_cuda != nullptr);
-        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11*s11;
+        s13 = ne12*s12;
     }
-    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
     char * dst_t;
@@ -1824,13 +1838,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 int i02 = i12 / r2;
 
                 CUBLAS_CHECK(
-                cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                           (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                    beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
+                           src1_f16 + i13*s13 + i12*s12,                  CUDA_R_16F,   s11,
+                    beta,  (      char *) dst_t + i13*nbd3 + i12*nbd2,    cu_data_type, ne0,
+                    cu_compute_type,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
     }
@@ -1841,15 +1855,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
-                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10, // strideB
-                beta,  (      char *)    dst_t, cu_data_type, ne01,      nb2/nb0,   // strideC
+                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
+                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+                beta,  dst_t,    cu_data_type, ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        const int ne23 = ne12*ne13;
+        const int64_t ne23 = ne12*ne13;
 
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
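After this hunk, src1_f16 no longer necessarily carries the tensor's original byte strides: when src1 had to be converted, the repacked buffer is dense and s11/s12/s13 are overwritten accordingly. That is why strideB switches from nb11/nb10, nb12/nb10 to s11, s12, and why strideC may use ne0, ne1*ne0 under the new ggml_is_contiguous(dst) assertion. The two stride regimes, condensed for reference (illustrative, not part of the patch):

    int64_t s11, s12, s13;
    if (src1->type == GGML_TYPE_F16) {
        const size_t ts = ggml_type_size(src1->type);  // == sizeof(half)
        s11 = nb11/ts; s12 = nb12/ts; s13 = nb13/ts;   // original, possibly non-cont. strides
    } else {
        s11 = ne10;                                    // repacked by to_fp16_cuda: dense rows,
        s12 = ne11*s11;                                // dense planes,
        s13 = ne12*s12;                                // dense batches
    }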
@@ -1861,8 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
-                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
+                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
                 nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
@@ -1871,8 +1885,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1936,7 +1950,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
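Finally, a worked example (values chosen purely for illustration) of why the old batched-pointer strides nb12/2 and nb13/2 were only correct for contiguous src1: they halve the source byte strides to describe the converted half buffer, but that buffer is repacked densely, so the correct strides are s12*sizeof(half) and s13*sizeof(half). With ne10 = 4, ne11 = 2 and a padded F32 row stride nb11 = 64 bytes (so nb12 = ne11*nb11 = 128):

    old: nb12/2           = 128/2   = 64 bytes per src1 plane
    new: s12*sizeof(half) = (2*4)*2 = 16 bytes per src1 plane

The two agree only when nb11 == ne10*sizeof(float), i.e. when src1 is contiguous. Relatedly, the last hunk drops the dst->op_params[0] == GGML_PREC_DEFAULT condition, so the batched cuBLAS path is now also dispatched when a non-default precision is requested.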