diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index c02bfb59..b0a3cc88 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -128,14 +128,9 @@ option(GGML_LLAMAFILE "ggml: use LLAMAFILE"
 option(GGML_CUDA                        "ggml: use CUDA"                  OFF)
 option(GGML_MUSA                        "ggml: use MUSA"                  OFF)
-option(GGML_CUDA_FORCE_DMMV             "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
 option(GGML_CUDA_FORCE_MMQ              "ggml: use mmq kernels instead of cuBLAS" OFF)
 option(GGML_CUDA_FORCE_CUBLAS           "ggml: always use cuBLAS instead of mmq kernels" OFF)
-set   (GGML_CUDA_DMMV_X   "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
-set   (GGML_CUDA_MMV_Y     "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
 option(GGML_CUDA_F16                    "ggml: use 16 bit floats for some calculations" OFF)
-set   (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
-                                        "ggml: iters./thread per block for Q2_K/Q6_K")
 set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                         "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY           "ggml: do not use peer to peer copies" OFF)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 07f04332..ef56e944 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -16,11 +16,11 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmv.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -1020,114 +1020,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
-static __global__ void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi = __half2float(x[ix]);
-
-        const int row_y = col_x;
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi = __half2float(x[ix]);
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
 static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
@@ -1654,58 +1546,6 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    const int64_t row_stride_x     = nb01 / sizeof(half);
-    const int64_t channel_stride_x = nb02 / sizeof(half);
-
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
 static __global__ void k_compute_batched_ptrs(
     const half * src0_as_f16, const half * src1_as_f16, char * dst,
     const void ** ptrs_src, void ** ptrs_dst,
@@ -1879,21 +1719,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+    bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-    bool use_mul_mat_q = ggml_is_quantized(src0->type)
+    bool use_mul_mat_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
-    bool any_gpus_with_slow_fp16 = false;
+    bool any_gpus_with_slow_fp16   = false;
+    bool any_gpus_without_fp16_mma = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1904,14 +1740,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc            = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+            const int cc              = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
         }
     } else {
-        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
     }
 
     // debug helpers
@@ -1922,18 +1760,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 860552f3..3dde0f36 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -54,21 +54,12 @@ if (CUDAToolkit_FOUND)
     target_link_libraries(ggml-cuda PRIVATE ggml-base)
     target_include_directories(ggml-cuda PRIVATE . ..)
 
-    # TODO: change the definitions to this target only
-
-    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
     if (GGML_CUDA_GRAPHS)
         add_compile_definitions(GGML_CUDA_USE_GRAPHS)
     endif()
 
-    if (GGML_CUDA_FORCE_DMMV)
-        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -81,10 +72,6 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
-    if (DEFINED GGML_CUDA_DMMV_Y)
-        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
-    endif()
-
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
new file mode 100644
index 00000000..cfe91f42
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -0,0 +1,223 @@
+#include "common.cuh"
+#include "mmv.cuh"
+
+template <typename type_acc, int block_size>
+static __global__ void mul_mat_vec(
+        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+    const int64_t row     = blockIdx.x;
+    const int64_t channel = blockIdx.z;
+    const int     tid     = threadIdx.x;
+
+    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   += channel                *stride_channel_y;
+    dst += channel                *stride_channel_dst;
+
+    const half2  * x2 = (const half2  *) x;
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
+    if (block_size > WARP_SIZE) {
+        if (tid < WARP_SIZE) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf;
+
+    if (std::is_same<type_acc, float>::value) {
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = __half22float2(x2[col2]);
+            const float2 tmpy = y2[col2];
+            sumf += tmpx.x * tmpy.x;
+            sumf += tmpx.y * tmpy.y;
+        }
+    } else {
+#ifdef FP16_AVAILABLE
+        half2 sumh2 = make_half2(0.0f, 0.0f);
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmp = y2[col2];
+            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+        }
+
+        sumf = __low2float(sumh2) + __high2float(sumh2);
+#else
+        NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+    }
+
+    sumf = warp_reduce_sum(sumf);
+
+    if (block_size > WARP_SIZE) {
+        buf_iw[tid/WARP_SIZE] = sumf;
+        __syncthreads();
+        if (tid >= WARP_SIZE) {
+            return;
+        }
+        sumf = buf_iw[tid];
+        sumf = warp_reduce_sum(sumf);
+    }
+
+    if (tid != 0) {
+        return;
+    }
+
+    dst[row] = sumf;
+}
+
+template <typename type_acc>
+static void launch_mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols      % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    const int64_t channel_ratio = nchannels_y / nchannels_x;
+
+    int64_t block_size_best = WARP_SIZE;
+    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int smem = WARP_SIZE*sizeof(float);
+    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case  32: {
+            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  64: {
+            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  96: {
+            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case 128: {
+            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case 160: {
+            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case 192: {
+            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case 224: {
+            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case 256: {
+            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+static void mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    switch (prec) {
+        case GGML_PREC_DEFAULT: {
+            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+        case GGML_PREC_F32: {
+            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    GGML_ASSERT(src1->ne[1] == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const half  * src0_d = (const half  *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne12 = src1->ne[2];
+    GGML_ASSERT(dst->ne[2] == ne12);
+
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT( dst->ne[3] == 1);
+
+    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
+    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
+    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
+    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+
+    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+}
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00     = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1_ncols == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t channel_stride_x   = 0;
+    const int64_t channel_stride_y   = 0;
+    const int64_t channel_stride_dst = 0;
+
+    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh
new file mode 100644
index 00000000..78a1cd4a
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmv.cuh
@@ -0,0 +1,12 @@
+#include "common.cuh"
+
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index 5ed186de..fccf8eb8 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -75,18 +75,11 @@ target_include_directories(ggml-hip PRIVATE . ..)
 target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
 
 add_compile_definitions(GGML_USE_HIP)
-add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
 
 if (GGML_HIP_UMA)
     add_compile_definitions(GGML_HIP_UMA)
 endif()
 
-if (GGML_CUDA_FORCE_DMMV)
-    add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-endif()
-
 if (GGML_CUDA_FORCE_MMQ)
     add_compile_definitions(GGML_CUDA_FORCE_MMQ)
 endif()
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index 8edc75cc..f3c01369 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -58,19 +58,12 @@ if (MUSAToolkit_FOUND)
     target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
     add_compile_definitions(GGML_USE_MUSA)
 
-    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})

    if (GGML_CUDA_GRAPHS)
        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
    endif()

-    if (GGML_CUDA_FORCE_DMMV)
-        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -83,10 +76,6 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
-    if (DEFINED GGML_CUDA_DMMV_Y)
-        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
-    endif()
-
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
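
Note on the reduction in the new mul_mat_vec kernel (illustration only, not part of the patch): each block reduces its partial dot products in two stages, a shuffle-based reduction inside each warp, then one partial sum per warp funneled through WARP_SIZE floats of shared memory and reduced again by the first warp. That is why buf_iw needs only WARP_SIZE entries and why threads with tid >= WARP_SIZE can exit once their warp's partial is stored. The standalone sketch below demonstrates the same scheme on a plain FP32 dot product; the file name sketch.cu and the kernel dot_kernel are hypothetical, warp_reduce_sum is a local stand-in for the helper assumed to live in common.cuh, and the shared buffer is statically sized instead of using dynamic shared memory.

// sketch.cu -- standalone illustration of the two-stage reduction in mmv.cu.
// Build with: nvcc -o sketch sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Stand-in for the warp_reduce_sum helper from common.cuh: a butterfly
// reduction that leaves the full warp sum in every lane.
static __device__ float warp_reduce_sum(float x) {
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset);
    }
    return x;
}

template <int block_size>
static __global__ void dot_kernel(const float * x, const float * y, float * out, const int n) {
    const int tid = threadIdx.x;

    // Stage 0: each thread accumulates a strided partial sum,
    // analogous to the col2 loop in mul_mat_vec.
    float sumf = 0.0f;
    for (int i = tid; i < n; i += block_size) {
        sumf += x[i]*y[i];
    }

    // Stage 1: reduce within each warp.
    sumf = warp_reduce_sum(sumf);

    // Stage 2: one partial per warp goes through shared memory and the first
    // warp reduces those partials; mmv.cu sizes its dynamic shared memory to
    // WARP_SIZE floats for the same reason.
    __shared__ float buf_iw[WARP_SIZE];
    if (block_size > WARP_SIZE) {
        if (tid < WARP_SIZE) {
            buf_iw[tid] = 0.0f; // slots for warps that don't exist stay zero
        }
        __syncthreads();
        buf_iw[tid/WARP_SIZE] = sumf; // all lanes of a warp store the same value
        __syncthreads();
        if (tid >= WARP_SIZE) {
            return; // matches the early exit in mul_mat_vec
        }
        sumf = warp_reduce_sum(buf_iw[tid]);
    }

    if (tid == 0) {
        *out = sumf;
    }
}

int main() {
    const int n = 4096;
    float *x, *y, *out;
    cudaMallocManaged(&x, n*sizeof(float));
    cudaMallocManaged(&y, n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; i++) {
        x[i] = 1.0f;
        y[i] = 2.0f;
    }
    dot_kernel<256><<<1, 256>>>(x, y, out, n);
    cudaDeviceSynchronize();
    printf("dot = %.1f (expected %.1f)\n", *out, 2.0f*n);
    cudaFree(x); cudaFree(y); cudaFree(out);
    return 0;
}

The sketch hard-codes 256 threads for brevity; the patch instead picks, in launch_mul_mat_vec_cuda, the smallest block size up to 256 that minimizes the number of loop iterations per thread, trading loop overhead against occupancy.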