llama: Add support for RWKV v7 architecture (llama/12412)

* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code-format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* fix MUSA build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia
2025-03-18 07:27:50 +08:00
committed by Georgi Gerganov
parent cfc2560e41
commit 52c4c03b0a
21 changed files with 1751 additions and 85 deletions

View File

@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_sycl_l2_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_sycl_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
default: