CUDA: non-contiguous (RMS) norm support (llama/11659)

* CUDA: non-contiguous (RMS) norm support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Johannes Gäßler, 2025-02-04 22:21:42 +01:00 (committed by Georgi Gerganov)
parent c310272fa0
commit bae6bbf487
4 changed files with 66 additions and 34 deletions

ggml-cuda/ggml-cuda.cu

@@ -38,6 +38,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include <algorithm>
 #include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
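
With this change the CUDA backend reports GGML_OP_NORM and GGML_OP_RMS_NORM as supported for any input layout, while GGML_OP_GROUP_NORM still requires a contiguous src0. Below is a minimal sketch of a graph that would exercise the new non-contiguous path, assuming the public ggml API; the function name, tensor sizes and permutation are illustrative, not taken from the patch:

#include "ggml.h"

// Illustrative only: apply RMS norm to a permuted (non-contiguous) view.
static struct ggml_cgraph * build_noncontig_rms_norm(struct ggml_context * ctx) {
    // example activations: 64 elements per row, 8 rows, 4 channels
    struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4);

    // swap dims 1 and 2: rows stay packed (nb[0] == sizeof(float)),
    // but ggml_is_contiguous(xp) is false
    struct ggml_tensor * xp = ggml_permute(ctx, x, 0, 2, 1, 3);

    // before this commit supports_op required a contiguous src0 here; now the op is
    // accepted and the kernel walks xp through the strides passed by the host wrapper
    struct ggml_tensor * y  = ggml_rms_norm(ctx, xp, 1e-6f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    return gf;
}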

ggml-cuda/norm.cu

@@ -1,12 +1,20 @@
 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
     const int tid = threadIdx.x;
 
-    x += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
     const int tid = threadIdx.x;
 
-    x += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
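
The updated kernels handle one row per block, with the source addressed through explicit element strides (so it may be any view whose rows are packed) and the destination always written densely. A plain-C reference sketch of the same addressing, for orientation only; the names mirror the kernel parameters, and the block-level parallel reduction is simplified to a serial loop:

#include <math.h>
#include <stdint.h>

// Reference-only sketch of the addressing used by the updated rms_norm_f32:
// x is read via element strides s01/s02/s03, dst is written as a packed array.
static void rms_norm_ref(const float * x, float * dst,
                         int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                         int64_t s01, int64_t s02, int64_t s03, float eps) {
    for (int64_t i03 = 0; i03 < ne03; i03++) {         // sample  (blockIdx.z)
    for (int64_t i02 = 0; i02 < ne02; i02++) {         // channel (blockIdx.y)
    for (int64_t i01 = 0; i01 < ne01; i01++) {         // row     (blockIdx.x)
        const float * xr = x   + i03*s03 + i02*s02 + i01*s01;        // strided read
        float       * dr = dst + ((i03*ne02 + i02)*ne01 + i01)*ne00; // dense write
        float sum = 0.0f;
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            sum += xr[i00]*xr[i00];
        }
        const float scale = 1.0f/sqrtf(sum/ne00 + eps);
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            dr[i00] = scale*xr[i00];
        }
    }}}
}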
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
@@ -233,19 +254,22 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -275,19 +297,22 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
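
The host wrappers no longer assert full contiguity. GGML_TENSOR_UNARY_OP_LOCALS exposes src0's sizes (ne0x) and byte strides (nb0x), and the wrappers convert the byte strides to element strides, requiring only that elements within a row are packed (nb00 == ggml_type_size(src0->type)); a transposed view would still trip that assert. A small hedged illustration, with example sizes not taken from the patch, of the values that end up being passed for a permuted F32 tensor:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    // illustrative context size only
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4); // packed
    struct ggml_tensor * xp = ggml_permute(ctx, x, 0, 2, 1, 3);                 // swap dims 1 and 2

    // same conversion as the wrappers: byte strides -> element strides
    const size_t ts0 = ggml_type_size(xp->type); // 4 bytes for F32; xp->nb[0] == ts0 still holds
    printf("s01 = %zu, s02 = %zu, s03 = %zu\n",
           xp->nb[1]/ts0, xp->nb[2]/ts0, xp->nb[3]/ts0); // prints: s01 = 512, s02 = 64, s03 = 2048

    ggml_free(ctx);
    return 0;
}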

ggml-metal/ggml-metal.m

@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ROPE:
         {
             const int mode = ((const int32_t *) op->op_params)[2];

ggml-vulkan/ggml-vulkan.cpp

@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+            return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL: