From 52c4c03b0ab23be660b578dd9f1fd7ae399d852c Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Tue, 18 Mar 2025 07:27:50 +0800 Subject: [PATCH] llama: Add support for RWKV v7 architecture (llama/12412) * ggml: Add op l2_norm Signed-off-by: Molly Sophia * ggml: Add op rwkv_wkv7 Signed-off-by: Molly Sophia * llama: Add support for RWKV7 and ARWKV7 models Signed-off-by: Molly Sophia * llama: fix inference with RWKV6Qwen2 Signed-off-by: Molly Sophia * llama: add more (a)rwkv7 variants in size Signed-off-by: Molly Sophia * Apply code-format changes Signed-off-by: Molly Sophia * fix MUSA build Signed-off-by: Molly Sophia * llama: fix shape error with rwkv using llama-parallel Signed-off-by: Molly Sophia --------- Signed-off-by: Molly Sophia --- ggml/include/ggml.h | 24 ++ ggml/src/ggml-cpu/ggml-cpu.c | 255 ++++++++++++++- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +- ggml/src/ggml-cuda/norm.cu | 116 +++++++ ggml/src/ggml-cuda/norm.cuh | 2 + ggml/src/ggml-cuda/wkv.cu | 199 ++++++++++++ ggml/src/ggml-cuda/wkv.cuh | 7 + ggml/src/ggml-metal/ggml-metal-impl.h | 7 + ggml/src/ggml-metal/ggml-metal.m | 122 +++++++ ggml/src/ggml-metal/ggml-metal.metal | 221 +++++++++++++ ggml/src/ggml-sycl/backend.hpp | 2 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 14 + ggml/src/ggml-sycl/norm.cpp | 108 +++++++ ggml/src/ggml-sycl/norm.hpp | 6 + ggml/src/ggml-sycl/wkv.cpp | 305 ++++++++++++++++++ ggml/src/ggml-sycl/wkv.hpp | 10 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 206 +++++++----- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 41 +++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp | 91 ++++++ ggml/src/ggml.c | 87 ++++- 21 files changed, 1751 insertions(+), 85 deletions(-) create mode 100644 ggml/src/ggml-cuda/wkv.cu create mode 100644 ggml/src/ggml-cuda/wkv.cuh create mode 100644 ggml/src/ggml-sycl/wkv.cpp create mode 100644 ggml/src/ggml-sycl/wkv.hpp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2e5076d3..cb3edb10 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -454,6 +454,7 @@ extern "C" { GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, GGML_OP_GROUP_NORM, + GGML_OP_L2_NORM, GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, @@ -502,6 +503,7 @@ extern "C" { GGML_OP_ADD_REL_POS, GGML_OP_RWKV_WKV6, GGML_OP_GATED_LINEAR_ATTN, + GGML_OP_RWKV_WKV7, GGML_OP_UNARY, @@ -1095,6 +1097,18 @@ extern "C" { int n_groups, float eps); + // l2 normalize along rows + // used in rwkv v7 + GGML_API struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_rms_norm_back( @@ -1890,6 +1904,16 @@ extern "C" { struct ggml_tensor * state, float scale); + GGML_API struct ggml_tensor * ggml_rwkv_wkv7( + struct ggml_context * ctx, + struct ggml_tensor * r, + struct ggml_tensor * w, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6fc5d42f..a2f8d91d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm( } } +// ggml_compute_forward_l2_norm + +static void ggml_compute_forward_l2_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_ASSERT(eps >= 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + + const float scale = 1.0f/fmaxf(sqrtf(sum), eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_l2_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_l2_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_mul_mat static void ggml_compute_forward_mul_mat_one_chunk( @@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla( } } +// ggml_compute_forward_rwkv_wkv7 + +static void ggml_compute_forward_rwkv_wkv7_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const int64_t T = dst->src[1]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[1]; + const int64_t n_seqs = dst->src[6]->ne[1]; + const int64_t head_size = C / HEADS; + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { + return; + } + + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; + + float * r = (float *) dst->src[0]->data; + float * w = (float *) dst->src[1]->data; + float * k = (float *) dst->src[2]->data; + float * v = (float *) dst->src[3]->data; + float * a = (float *) dst->src[4]->data; + float * b = (float *) dst->src[5]->data; + + int64_t t_stride = HEADS * head_size; // Same to C + + int64_t h_stride = C / HEADS; + GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS + int64_t h_stride_2d = head_size * head_size; + + #if defined(GGML_SIMD) + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; + + for (int64_t ii = 0; ii < head_size; ii++) { + int64_t t_h_i_offset = t_h_offset + ii; + int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + + GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + + float sa = 0; + { + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); + ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); + sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); + } + } + GGML_F32_VEC_REDUCE(sa, sum); + } + + GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + + int64_t j = 0; + GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + for (; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; + int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; + + GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); + GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); + GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); + GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); + + GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); + // kv + s * decay + sa * b + state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); + state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); + GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); + + result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + } + } + GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); + + // There shouldn't be left-overs though. + for (; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; + + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v[t_h_i_offset] * k_val; + + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + } + } + } + } + #else + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + int64_t t_h_i_offset = t_h_offset + i; + int64_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float v_val = v[t_h_i_offset]; + + float sa = 0, result = 0; + for (int64_t j = 0; j < head_size; j++) { + sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j]; + } + + for (int64_t j = 0; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; + + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + result += state_cur[h_2d_i_j_offset] * r_val; + } + dst_data[t_h_i_offset] = result; + } + } + } + #endif +} + + +static void ggml_compute_forward_rwkv_wkv7( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rwkv_wkv7_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_group_norm(params, tensor); } break; + case GGML_OP_L2_NORM: + { + ggml_compute_forward_l2_norm(params, tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor); @@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_gla(params, tensor); } break; + case GGML_OP_RWKV_WKV7: + { + ggml_compute_forward_rwkv_wkv7(params, tensor); + } break; case GGML_OP_MAP_UNARY: { ggml_unary_op_f32_t fun; @@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: @@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: { n_tasks = n_threads; } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: - case GGML_OP_RWKV_WKV6: - case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9bba398c..8fb06382 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -36,7 +36,7 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/wkv6.cuh" +#include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" #include "ggml.h" @@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GROUP_NORM: ggml_cuda_op_group_norm(ctx, dst); break; + case GGML_OP_L2_NORM: + ggml_cuda_op_l2_norm(ctx, dst); + break; case GGML_OP_CONCAT: ggml_cuda_op_concat(ctx, dst); break; @@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_cuda_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_RWKV_WKV7: + ggml_cuda_op_rwkv_wkv7(ctx, dst); + break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_cuda_cross_entropy_loss_back(ctx, dst); break; @@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return true; case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; @@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: return true; case GGML_OP_FLASH_ATTN_EXT: { #ifndef FLASH_ATTN_AVAILABLE diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index f127616e..0020dbce 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32( } } +// template +// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) { +// const int row = blockIdx.x*blockDim.y + threadIdx.y; +// const int tid = threadIdx.x; + +// float tmp = 0.0f; // partial sum for thread in warp + +// for (int col = tid; col < ncols; col += block_size) { +// const float xi = x[row*ncols + col]; +// tmp += xi * xi; +// } + +// // sum up partial sums +// tmp = warp_reduce_sum(tmp); +// if (block_size > WARP_SIZE) { +// __shared__ float s_sum[32]; +// int warp_id = threadIdx.x / WARP_SIZE; +// int lane_id = threadIdx.x % WARP_SIZE; +// if (lane_id == 0) { +// s_sum[warp_id] = tmp; +// } +// __syncthreads(); +// tmp = s_sum[lane_id]; +// tmp = warp_reduce_sum(tmp); +// } + +// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html +// const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + +// for (int col = tid; col < ncols; col += block_size) { +// dst[row*ncols + col] = scale * x[row*ncols + col]; +// } +// } + +template +static __global__ void l2_norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html + const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + static void norm_f32_cuda( const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { @@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * } } +static void l2_norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + l2_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *) src0->data; @@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index d63d3438..706a5660 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/wkv.cu b/ggml/src/ggml-cuda/wkv.cu new file mode 100644 index 00000000..d2fced70 --- /dev/null +++ b/ggml/src/ggml-cuda/wkv.cu @@ -0,0 +1,199 @@ +#include "common.cuh" +#include "wkv.cuh" + +template +static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + __syncthreads(); + _tf[tid] = tf[head_i * head_size + tid]; + __syncthreads(); + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + __syncthreads(); + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4& k = (float4&)(_k[j]); + const float4& r = (float4&)(_r[j]); + const float4& tf = (float4&)(_tf[j]); + const float4& td = (float4&)(_td[j]); + float4& s = (float4&)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + y += r.x * (tf.x * kv.x + s.x); + y += r.y * (tf.y * kv.y + s.y); + y += r.z * (tf.z * kv.z + s.z); + y += r.w * (tf.w * kv.w + s.w); + + s.x = s.x * td.x + kv.x; + s.y = s.y * td.y + kv.y; + s.z = s.z * td.z + kv.z; + s.w = s.w * td.w + kv.w; + } + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +template +static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size]; + +#ifndef GGML_USE_MUSA + #pragma unroll +#endif + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + __syncthreads(); + + float sa = 0; + #pragma unroll + for (int j = 0; j < head_size; j += 4) + { + const float4& a = (float4&)(_a[j]); + const float4& s = (float4&)(state[j]); + sa += a.x * s.x; + sa += a.y * s.y; + sa += a.z * s.z; + sa += a.w * s.w; + } + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4& r = (float4&)(_r[j]); + const float4& w = (float4&)(_w[j]); + const float4& k = (float4&)(_k[j]); + const float4& b = (float4&)(_b[j]); + float4& s = (float4&)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + s.x = s.x * w.x + kv.x + sa * b.x; + s.y = s.y * w.y + kv.y + sa * b.y; + s.z = s.z * w.z + kv.z + sa * b.z; + s.w = s.w * w.w + kv.w + sa * b.w; + + y += s.x * r.x; + y += s.y * r.y; + y += s.z * r.z; + y += s.w * r.w; + } + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i]; + } +} + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * k_d = (const float *)dst->src[0]->data; + const float * v_d = (const float *)dst->src[1]->data; + const float * r_d = (const float *)dst->src[2]->data; + const float * tf_d = (const float *)dst->src[3]->data; + const float * td_d = (const float *)dst->src[4]->data; + const float * s_d = (const float *)dst->src[5]->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); + + if (C / H == CUDA_WKV_BLOCK_SIZE) { + rwkv_wkv_f32<<>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); + } else { + rwkv_wkv_f32<<>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); + } +} + +void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * r_d = (const float *)dst->src[0]->data; + const float * w_d = (const float *)dst->src[1]->data; + const float * k_d = (const float *)dst->src[2]->data; + const float * v_d = (const float *)dst->src[3]->data; + const float * a_d = (const float *)dst->src[4]->data; + const float * b_d = (const float *)dst->src[5]->data; + const float * s_d = (const float *)dst->src[6]->data; + + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); + + if (C / H == CUDA_WKV_BLOCK_SIZE) { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } +} diff --git a/ggml/src/ggml-cuda/wkv.cuh b/ggml/src/ggml-cuda/wkv.cuh new file mode 100644 index 00000000..9623dd7f --- /dev/null +++ b/ggml/src/ggml-cuda/wkv.cuh @@ -0,0 +1,7 @@ +#include "common.cuh" + +#define CUDA_WKV_BLOCK_SIZE 64 + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a58c474e..1e954b4c 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -285,6 +285,13 @@ typedef struct { float eps; } ggml_metal_kargs_rms_norm; +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_l2_norm; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index e51a4169..af65e7d9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -184,10 +184,13 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, + GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, @@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); @@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_GROUP_NORM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: return true; @@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_RWKV_WKV6: + { + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; + id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; + id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&B length:sizeof(B) atIndex:7]; + [encoder setBytes:&T length:sizeof(T) atIndex:8]; + [encoder setBytes:&C length:sizeof(C) atIndex:9]; + [encoder setBytes:&H length:sizeof(H) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; + } break; + case GGML_OP_RWKV_WKV7: + { + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + size_t offs_src6 = 0; + + id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; + id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; + id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; + id id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&B length:sizeof(B) atIndex:8]; + [encoder setBytes:&T length:sizeof(T) atIndex:9]; + [encoder setBytes:&C length:sizeof(C) atIndex:10]; + [encoder setBytes:&H length:sizeof(H) atIndex:11]; + + [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; + } break; case GGML_OP_MUL_MAT: { GGML_ASSERT(ne00 == ne10); @@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node( const int64_t nrows = ggml_nrows(src0); + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_L2_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_GROUP_NORM: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ad9d42a3..3cef81b7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32( } } +kernel void kernel_rwkv_wkv6_f32( + device const float * k, + device const float * v, + device const float * r, + device const float * tf, + device const float * td, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _k[head_size]; + threadgroup float _r[head_size]; + threadgroup float _tf[head_size]; + threadgroup float _td[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + _tf[tid] = tf[head_id * head_size + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0; + + for (uint j = 0; j < head_size; j += 4) { + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + float4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} + +kernel void kernel_rwkv_wkv7_f32( + device const float * r, + device const float * w, + device const float * k, + device const float * v, + device const float * a, + device const float * b, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _r[head_size]; + threadgroup float _w[head_size]; + threadgroup float _k[head_size]; + threadgroup float _a[head_size]; + threadgroup float _b[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0, sa = 0.0; + + float4 sa_vec(0.0); + + for (int j = 0; j < head_size; j += 4) { + float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + sa_vec += a_vec * s_vec; + } + sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3]; + + for (uint j = 0; j < head_size; j += 4) { + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(s_vec, r_vec); + + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} + kernel void kernel_argmax( device const void * x, device int32_t * dst, @@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm( } } +kernel void kernel_l2_norm( + constant ggml_metal_kargs_l2_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float scale = 1.0f/sqrt(max(sumf, args.eps)); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + kernel void kernel_group_norm( device const float * src0, device float * dst, diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 577ff51f..73d807ca 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -26,7 +26,7 @@ #include "softmax.hpp" #include "tsembd.hpp" #include "im2col.hpp" -#include "wkv6.hpp" +#include "wkv.hpp" #include "outprod.hpp" #include "element_wise.hpp" #include "cpy.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 05984d8c..477652ab 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds GGML_SYCL_DEBUG("call %s done\n", __func__); } +static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); @@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_RMS_NORM: ggml_sycl_rms_norm(ctx, dst); break; + case GGML_OP_L2_NORM: + ggml_sycl_l2_norm(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { return false; @@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_RWKV_WKV6: ggml_sycl_op_rwkv_wkv6(ctx, dst); break; + case GGML_OP_RWKV_WKV7: + ggml_sycl_op_rwkv_wkv7(ctx, dst); + break; case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; @@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return (op->src[0]->type == GGML_TYPE_F32); case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: @@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; default: diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 9cf2be15..6439db21 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa } } +static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + const int tid = item_ct1.get_local_id(2); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:3: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + size_t nreduce = nwarps / WARP_SIZE; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = scale * x[row * ncols + col]; + } +} + static void norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, queue_ptr stream, int device) { @@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, } } +static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, + const int nrows, const float eps, + queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:19: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst, const float* src0_dd, const float* src1_dd, float* dst_dd, @@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr (void)dst; (void)src1_dd; } + +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + + (void)src1; + (void)dst; + (void)src1_dd; +} diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp index a9ad9156..11e91680 100644 --- a/ggml/src/ggml-sycl/norm.hpp +++ b/ggml/src/ggml-sycl/norm.hpp @@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* float* dst_dd, const queue_ptr& main_stream); +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream); + #endif // GGML_SYCL_NORM_HPP diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp new file mode 100644 index 00000000..540f6fbf --- /dev/null +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -0,0 +1,305 @@ +#include +#include "wkv.hpp" + +constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE + +// Helper function for the main kernel +template +static void rwkv_wkv6_f32_kernel( + const int B, const int T, const int C, const int H, + const float* k, const float* v, const float* r, + const float* tf, const float* td, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + // Set up shared memory pointers + float* _k = shared_mem; + float* _r = _k + head_size; + float* _tf = _r + head_size; + float* _td = _tf + head_size; + + // Local state array + float state[block_size]; + + // Load initial state + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + // Sync threads before shared memory operations + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load time-mixing parameters + _tf[tid] = tf[head_i * head_size + tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Main sequence processing loop + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load current timestep data to shared memory + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0; + + // Process in chunks of 4 for better vectorization + sycl::float4 k4, r4, tf4, td4, s4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + // Load data in vec4 chunks + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute key-value product + sycl::float4 kv4 = k4 * _v; + + // Accumulate weighted sum + y += sycl::dot(r4, tf4 * kv4 + s4); + + // Update state + s4 = s4 * td4 + kv4; + + // Store updated state + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + // Save final state + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +template +static void rwkv_wkv7_f32_kernel( + const int B, const int T, const int C, const int H, + const float* r, const float* w, const float* k, const float* v, + const float* a, const float* b, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float* _r = shared_mem; + float* _w = _r + head_size; + float* _k = _w + head_size; + float* _a = _k + head_size; + float* _b = _a + head_size; + + float state[block_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0, sa = 0; + sycl::float4 a4, s4; + + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + sa += sycl::dot(a4, s4); + } + + sycl::float4 r4, w4, k4, b4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + sycl::float4 kv4 = k4 * _v; + + s4 = s4 * w4 + kv4 + sa * b4; + y += sycl::dot(r4, s4); + + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i]; + } +} + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + const float* k_d = (const float*)dst->src[0]->data; + const float* v_d = (const float*)dst->src[1]->data; + const float* r_d = (const float*)dst->src[2]->data; + const float* tf_d = (const float*)dst->src[3]->data; + const float* td_d = (const float*)dst->src[4]->data; + const float* s_d = (const float*)dst->src[5]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64 + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + if (C / H == WKV_BLOCK_SIZE) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv6_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv6_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } + + GGML_UNUSED(src0); + GGML_UNUSED(src1); +} + +void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + const float* r_d = (const float*)dst->src[0]->data; + const float* w_d = (const float*)dst->src[1]->data; + const float* k_d = (const float*)dst->src[2]->data; + const float* v_d = (const float*)dst->src[3]->data; + const float* a_d = (const float*)dst->src[4]->data; + const float* b_d = (const float*)dst->src[5]->data; + const float* s_d = (const float*)dst->src[6]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + if (C / H == WKV_BLOCK_SIZE) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } + + GGML_UNUSED(src0); + GGML_UNUSED(src1); +} diff --git a/ggml/src/ggml-sycl/wkv.hpp b/ggml/src/ggml-sycl/wkv.hpp new file mode 100644 index 00000000..9f34a100 --- /dev/null +++ b/ggml/src/ggml-sycl/wkv.hpp @@ -0,0 +1,10 @@ +#ifndef GGML_SYCL_WKV_HPP +#define GGML_SYCL_WKV_HPP + +#include "common.hpp" + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_WKV_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 97398f07..c0ee5dad 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -304,6 +304,7 @@ struct vk_device_struct { vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; + vk_pipeline pipeline_l2_norm_f32; vk_pipeline pipeline_gelu_f32; vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_silu_f32; @@ -328,6 +329,7 @@ struct vk_device_struct { vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} @@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants { uint32_t H; }; +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); for (auto &c : compiles) { @@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_back_f32; } return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: @@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; case GGML_OP_OPT_STEP_ADAMW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_opt_step_adamw_f32; @@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: @@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } -static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { - const ggml_tensor * k = dst->src[0]; - const ggml_tensor * v = dst->src[1]; - const ggml_tensor * r = dst->src[2]; - const ggml_tensor * tf = dst->src[3]; - const ggml_tensor * td = dst->src[4]; - const ggml_tensor * state = dst->src[5]; +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } - GGML_ASSERT(!ggml_is_quantized(k->type)); - GGML_ASSERT(!ggml_is_quantized(v->type)); - GGML_ASSERT(!ggml_is_quantized(r->type)); - GGML_ASSERT(!ggml_is_quantized(tf->type)); - GGML_ASSERT(!ggml_is_quantized(td->type)); - GGML_ASSERT(!ggml_is_quantized(state->type)); GGML_ASSERT(dst->buffer != nullptr); - vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); GGML_ASSERT(pipeline != nullptr); if (dryrun) { @@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc } ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; - ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; - ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; - ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; - ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; - ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; - size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; - bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); - ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); - ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); - ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); - ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); - - K_uma = d_K != nullptr; - V_uma = d_V != nullptr; - R_uma = d_R != nullptr; - TF_uma = d_TF != nullptr; - TD_uma = d_TD != nullptr; - STATE_uma = d_State != nullptr; - DST_uma = d_D != nullptr; + dst_uma = d_D != nullptr; } - if (!K_uma) { - d_K = k_buf_ctx->dev_buffer; - k_offset = vk_tensor_offset(k) + k->view_offs; + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } } - if (!V_uma) { - d_V = v_buf_ctx->dev_buffer; - v_offset = vk_tensor_offset(v) + v->view_offs; - } - if (!R_uma) { - d_R = r_buf_ctx->dev_buffer; - r_offset = vk_tensor_offset(r) + r->view_offs; - } - if (!TF_uma) { - d_TF = tf_buf_ctx->dev_buffer; - tf_offset = vk_tensor_offset(tf) + tf->view_offs; - } - if (!TD_uma) { - d_TD = td_buf_ctx->dev_buffer; - td_offset = vk_tensor_offset(td) + td->view_offs; - } - if (!STATE_uma) { - d_State = state_buf_ctx->dev_buffer; - state_offset = vk_tensor_offset(state) + state->view_offs; - } - if (!DST_uma) { + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; dst_offset = vk_tensor_offset(dst) + dst->view_offs; } - const uint64_t k_size = ggml_nbytes(k); - const uint64_t v_size = ggml_nbytes(v); - const uint64_t r_size = ggml_nbytes(r); - const uint64_t tf_size = ggml_nbytes(tf); - const uint64_t td_size = ggml_nbytes(td); - const uint64_t state_size = ggml_nbytes(state); - const uint64_t dst_size = ggml_nbytes(dst); - std::array elements = { (uint32_t)(pc.B * pc.H), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_K, k_offset, k_size }, - vk_subbuffer{ d_V, v_offset, v_size }, - vk_subbuffer{ d_R, r_offset, r_size }, - vk_subbuffer{ d_TF, tf_offset, tf_size }, - vk_subbuffer{ d_TD, td_offset, td_size }, - vk_subbuffer{ d_State, state_offset, state_size }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } } static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { @@ -6224,7 +6225,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, const size_t n_heads = dst->src[0]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1]; - ggml_vk_op_f32_rwkv6( + ggml_vk_op_f32_wkv( ctx, subctx, dst, { (uint32_t)n_seqs, @@ -6232,6 +6233,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, (uint32_t)n_embed, (uint32_t)n_heads, }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, dryrun ); } @@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: @@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); @@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: @@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_SUB: @@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; @@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); } else if (tensor->op == GGML_OP_SILU_BACK) { tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); @@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_RWKV_WKV6) { tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = src0->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 00000000..deba8c39 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ee1fec4e..eb2ad63f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -434,6 +434,7 @@ void process_shaders() { string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); @@ -528,6 +529,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); for (auto &c : compiles) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 00000000..88c1c02b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 89409bb0..2e081d59 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM", "RMS_NORM_BACK", "GROUP_NORM", + "L2_NORM", "MUL_MAT", "MUL_MAT_ID", @@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ADD_REL_POS", "RWKV_WKV6", "GATED_LINEAR_ATTN", + "RWKV_WKV7", "UNARY", @@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm(x)", "rms_norm_back(x)", "group_norm(x)", + "l2_norm(x)", "X*Y", "X[i]*Y", @@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "add_rel_pos(x)", "rwkv_wkv6(k, v, r, tf, td, s)", "gated_linear_attn(k, v, q, gate, s)", + "rwkv_wkv7(r, w, k, v, a, b, s)", "unary(x)", @@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace( return ggml_group_norm_impl(ctx, a, n_groups, eps, true); } +// ggml_l2_norm + +static struct ggml_tensor * ggml_l2_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_f32(result, 0, eps); + + result->op = GGML_OP_L2_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, true); +} + // ggml_mul_mat static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { @@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn( return result; } +// ggml_rwkv_wkv7 + +struct ggml_tensor * ggml_rwkv_wkv7( + struct ggml_context * ctx, + struct ggml_tensor * r, + struct ggml_tensor * w, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(w)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens); + GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens); + GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_WKV7; + result->src[0] = r; + result->src[1] = w; + result->src[2] = k; + result->src[3] = v; + result->src[4] = a; + result->src[5] = b; + result->src[6] = state; + + return result; +} + // ggml_unary static struct ggml_tensor * ggml_unary_impl(