llama: add support for QRWKV6 model architecture (llama/11001)

llama: add support for QRWKV6 model architecture (llama/11001)

* WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV: Some graph simplification

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add support for RWKV6Qwen2 with cpu and cuda GLA

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix some typos

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix wkv test & add gla test

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix cuda warning

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update README.md

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update ggml/src/ggml-cuda/gla.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* better sanity check skipping for QRWKV6 in llama-quant

thanks @compilade

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: compilade <git@compilade.net>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
Molly Sophia 2025-01-10 09:58:08 +08:00 committed by Georgi Gerganov
parent c3235bd81e
commit 06209f6683
9 changed files with 367 additions and 17 deletions

View File

@ -501,6 +501,7 @@ extern "C" {
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_UNARY,
@ -1859,6 +1860,15 @@ extern "C" {
struct ggml_tensor * td,
struct ggml_tensor * state);
GGML_API struct ggml_tensor * ggml_gated_linear_attn(
struct ggml_context * ctx,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * q,
struct ggml_tensor * g,
struct ggml_tensor * state,
float scale);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

View File

@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[3];
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[2];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[5]->ne[1];
const int64_t head_size = C / HEADS;
@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
}
}
// ggml_compute_forward_gla
static void ggml_compute_forward_gla_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[4]->ne[1];
const int64_t head_size = C / HEADS;
const float scale = ggml_get_op_params_f32(dst, 0);
float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
float * q = (float *) dst->src[2]->data;
float * g = (float *) dst->src[3]->data;
size_t t_stride = HEADS * head_size; // Same to C
size_t h_stride = C / HEADS;
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
size_t h_stride_2d = head_size * head_size;
if (ith == 0) {
memset(dst_data, 0, T * C * sizeof(float));
}
ggml_barrier(params->threadpool);
#if defined(__AVX__) && !defined(__AVX512F__)
#define GGML_F32X GGML_F32x8
#define GGML_F32X_SET1 GGML_F32x8_SET1
#define GGML_F32X_LOAD GGML_F32x8_LOAD
#define GGML_F32X_STORE GGML_F32x8_STORE
#define GGML_F32X_MUL GGML_F32x8_MUL
#define GGML_F32X_FMA GGML_F32x8_FMA
#define GLA_VECTOR_SIZE 8
#elif defined(__AVX512F__)
#define GGML_F32X GGML_F32x16
#define GGML_F32X_SET1 GGML_F32x16_SET1
#define GGML_F32X_LOAD GGML_F32x16_LOAD
#define GGML_F32X_STORE GGML_F32x16_STORE
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define GLA_VECTOR_SIZE 16
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
#define GGML_F32X_LOAD GGML_F32x4_LOAD
#define GGML_F32X_STORE GGML_F32x4_STORE
#define GGML_F32X_MUL GGML_F32x4_MUL
#define GGML_F32X_FMA GGML_F32x4_FMA
#define GLA_VECTOR_SIZE 4
#endif
#ifdef GLA_VECTOR_SIZE
const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float q_val = q[t_h_i_offset] * scale;
float g_val = g[t_h_i_offset];
// Broadcast scalar values to vectors
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
GGML_F32X q_vec = GGML_F32X_SET1(q_val);
GGML_F32X g_vec = GGML_F32X_SET1(g_val);
for (int64_t j = 0; j < vec_count; j++) {
size_t base_j = j * GLA_VECTOR_SIZE;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load x elements at once
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
// Compute kv = v * k
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
// Compute temp = prev_state * g + kv
GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
// Update dst: dst += temp * q
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
// Update state
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
}
// Handle remaining elements, this will not be used.
for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val + prev_state_val * g_val;
dst_data[t_h_j_offset] += temp_val * q_val;
state_cur[h_2d_i_j_offset] = temp_val;
}
}
}
}
#else
for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float q_val = q[t_h_i_offset] * scale;
float g_val = g[t_h_i_offset];
for (int64_t j = 0; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = prev_state_val * g_val + kv_val;
dst_data[t_h_j_offset] += temp_val * q_val;
state_cur[h_2d_i_j_offset] = temp_val;
}
}
}
}
#endif
}
static void ggml_compute_forward_gla(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gla_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv6(params, tensor);
} break;
case GGML_OP_GATED_LINEAR_ATTN:
{
ggml_compute_forward_gla(params, tensor);
} break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:

View File

@ -37,6 +37,7 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/gla.cuh"
#include <algorithm>
#include <array>
@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE

93
ggml/src/ggml-cuda/gla.cu Normal file
View File

@ -0,0 +1,93 @@
#include "common.cuh"
#include "gla.cuh"
template<int HEAD_SIZE>
static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = HEAD_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4 & k = (float4 &)(_k[j]);
const float4 & r = (float4 &)(_r[j]);
const float4 & td = (float4 &)(_td[j]);
float4 & s = (float4 &)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
y += r.x * s.x;
y += r.y * s.y;
y += r.z * s.z;
y += r.w * s.w;
}
dst[t] = y * scale;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * td_d = (const float *)dst->src[3]->data;
const float * s_d = (const float *)dst->src[4]->data;
const int64_t B = dst->src[4]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float scale;
memcpy(&scale, (float*)dst->op_params, sizeof(float));
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64 || C / H == 128);
if (C / H == 64) {
gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
} else {
gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
}
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;

View File

@ -109,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);

View File

@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[3];
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[2];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(

View File

@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV6",
"GATED_LINEAR_ATTN",
"UNARY",
@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"get_rel_pos(x)",
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)",
"unary(x)",
@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[2];
const int64_t n_tokens = k->ne[3];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(k->ne[1] == 1);
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
// TODO: RWKV v4 and v5
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}
@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
return result;
}
// ggml_gated_linear_attn
struct ggml_tensor * ggml_gated_linear_attn(
struct ggml_context * ctx,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * q,
struct ggml_tensor * g,
struct ggml_tensor * state,
float scale) {
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}
// concat output and new_state
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_f32(result, 0, scale);
result->op = GGML_OP_GATED_LINEAR_ATTN;
result->src[0] = k;
result->src[1] = v;
result->src[2] = q;
result->src[3] = g;
result->src[4] = state;
return result;
}
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(