mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-31 22:40:42 +00:00
* rwkv_wkv6 vulkan shader * RWKV_WKV6 Vulkan op tests passed Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Apply code format changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * add [[unroll]] and remove unnecessary conditions * add uma support * fix errors in EditorConfig Checker --------- Signed-off-by: Molly Sophia <mollysophia379@gmail.com> Co-authored-by: Molly Sophia <mollysophia379@gmail.com>
88 lines
2.7 KiB
Plaintext
88 lines
2.7 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
|
|
// Number of invocations per workgroup along x. main() also uses this value as
// head_size and as the length of the per-invocation state array, and the
// shared staging arrays are declared with this many elements.
#define BLOCK_SIZE 64

// 1-D workgroup: BLOCK_SIZE invocations along x, a single slice in y and z.
layout(
    local_size_x = BLOCK_SIZE,
    local_size_y = 1,
    local_size_z = 1
) in;
|
|
|
|
// Dispatch parameters supplied as push constants.
// The meanings below are inferred from how main() uses each field -- confirm
// against the host-side code that fills them in.
layout(push_constant) uniform Parameters {
    uint B;  // upper bound for batch_id (gl_WorkGroupID.x / H); presumably the number of sequences
    uint T;  // T / B is used as the token count per sequence; presumably the total token count
    uint C;  // stride between consecutive tokens in k/v/r/td/dst; presumably channels per token
    uint H;  // workgroups per batch entry (gl_WorkGroupID.x is split by H); presumably the head count
};
|
|
|
|
// Storage buffers. A_TYPE is not defined in this file; it is expected to be
// supplied when the shader is compiled (main() mixes it with vec4, so it is
// presumably float -- confirm with the build setup).

// Per-token inputs; main() reads all three at the same element index t.
layout(binding = 0) readonly buffer KBuf {
    A_TYPE k[];
};
layout(binding = 1) readonly buffer VBuf {
    A_TYPE v[];
};
layout(binding = 2) readonly buffer RBuf {
    A_TYPE r[];
};

// Per-head coefficients, read once per workgroup at head_id * head_size + tid.
// In main() they scale the current k*v term before it is added to the state
// for the output sum.
layout(binding = 3) readonly buffer TimeFBuf {
    A_TYPE tf[];
};

// Per-token coefficients, read at index t; in main() they multiply the running
// state before the new k*v term is added to it.
layout(binding = 4) readonly buffer TimeDBuf {
    A_TYPE td[];
};

// Initial recurrent state, loaded into each invocation's private state array.
layout(binding = 5) readonly buffer StateBuf {
    A_TYPE state_in[];
};

// Output: main() writes the per-token results at index t and then the final
// state starting at element offset T * C.
layout(binding = 6) buffer DstBuf {
    A_TYPE dst[];
};
|
|
|
|
// Workgroup-shared staging arrays. In main() each invocation writes element
// [tid] and, after a barrier, every invocation reads all BLOCK_SIZE elements.
shared A_TYPE _k[BLOCK_SIZE];
shared A_TYPE _r[BLOCK_SIZE];
shared A_TYPE _tf[BLOCK_SIZE];
shared A_TYPE _td[BLOCK_SIZE];
|
|
|
|
// Entry point. Each workgroup is mapped to one (batch_id, head_id) pair derived
// from gl_WorkGroupID.x, and each invocation keeps BLOCK_SIZE state values in a
// private array. For every token of the sequence the invocation:
//   1. helps stage k, r and td for the token into shared memory,
//   2. accumulates y = sum_j r[j] * (tf[j] * k[j] * v + state[j]),
//   3. updates state[j] = state[j] * td[j] + k[j] * v,
//   4. stores y to dst at the token's element index.
// After the last token the private state is written to dst past offset T * C.
void main() {
    const uint head_size = BLOCK_SIZE;
    const uint lane      = gl_LocalInvocationID.x;
    const uint group     = gl_WorkGroupID.x;
    const uint batch_id  = group / H;
    const uint head_id   = group % H;

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    // The condition depends only on the workgroup id, so either the whole
    // workgroup leaves here or none of it does, ahead of any barrier().
    if (!(batch_id < B && head_id < H)) {
        return;
    }

    // Element offset of this invocation's first state value; successive state
    // values are head_size elements apart.
    const uint state_base = batch_id * state_size
                          + head_id * head_size * head_size
                          + lane;

    // Private copy of the state, updated in place while walking the tokens.
    A_TYPE state[BLOCK_SIZE];
    [[unroll]] for (uint row = 0; row < head_size; row++) {
        state[row] = state_in[state_base + row * head_size];
    }

    // Element offset of this invocation's channel inside one token.
    const uint chan = head_id * head_size + lane;

    // tf does not depend on the token, so it is staged once per workgroup.
    barrier();
    _tf[lane] = tf[chan];
    barrier();

    // Element index of this channel in the first token of the sequence, and
    // the matching index one past the sequence's last token.
    const uint start_t = batch_id * n_seq_tokens * C + chan;
    const uint end_t   = (batch_id + 1) * n_seq_tokens * C + chan;

    for (uint t = start_t; t < end_t; t += C) {
        // Wait until every invocation has finished reading the previous
        // token's staged values before they are overwritten.
        barrier();
        _k[lane]  = k[t];
        _r[lane]  = r[t];
        _td[lane] = td[t];
        // Make the freshly staged values visible to the whole workgroup.
        barrier();

        const A_TYPE v_val = v[t];
        A_TYPE y = 0.0;

        // Process the staged values four at a time using vec4 arithmetic.
        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
            const vec4 k_vec  = vec4(_k[j],    _k[j + 1],    _k[j + 2],    _k[j + 3]);
            const vec4 r_vec  = vec4(_r[j],    _r[j + 1],    _r[j + 2],    _r[j + 3]);
            const vec4 tf_vec = vec4(_tf[j],   _tf[j + 1],   _tf[j + 2],   _tf[j + 3]);
            const vec4 td_vec = vec4(_td[j],   _td[j + 1],   _td[j + 2],   _td[j + 3]);
            const vec4 s_vec  = vec4(state[j], state[j + 1], state[j + 2], state[j + 3]);

            const vec4 kv = k_vec * v_val;

            // Output term uses the state as it was before this token's update.
            y += dot(r_vec, tf_vec * kv + s_vec);

            // Scale the old state by td and add the new k*v term.
            const vec4 s_next = s_vec * td_vec + kv;
            state[j]     = s_next.x;
            state[j + 1] = s_next.y;
            state[j + 2] = s_next.z;
            state[j + 3] = s_next.w;
        }

        dst[t] = y;
    }

    // The final state is stored after the T * C per-token outputs, using the
    // same element layout it was read with.
    const uint dst_state_base = T * C + state_base;
    [[unroll]] for (uint row = 0; row < head_size; row++) {
        dst[dst_state_base + row * head_size] = state[row];
    }
}
|