mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-01 06:50:41 +00:00
This assert fired while running Qwen_Qwen3-30B-A3B-Q2_K.gguf: `GGML_ASSERT(nei0 * nei1 <= 3072);`. The ids tensor is 8 x 512 (4096 entries), so increase the `row_ids` array size to 4096 to accommodate it.
443 lines · 15 KiB · Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : enable
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
|
|
|
#extension GL_EXT_integer_dot_product : require
|
|
|
|
#ifdef FLOAT16
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
#endif
|
|
|
|
#ifdef COOPMAT
|
|
#extension GL_KHR_cooperative_matrix : enable
|
|
#extension GL_KHR_memory_scope_semantics : enable
|
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
#endif
|
|
|
|
#ifdef MUL_MAT_ID
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
|
#endif
|
|
|
|
#include "types.comp"
|
|
|
|
// One-dimensional workgroup; its x size is specialization constant 0 (BLOCK_SIZE).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Matrix A (quantized). When the quant type defines A_TYPE_PACKED32, the same
// binding is additionally viewed as 32-bit packed words; both views alias binding 0.
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
// Matrix B, stored as q8_1-packed blocks (main() indexes one block per BK values).
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
// Output matrix D.
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// Expert routing ids; main() reads entry (ii0, ii1) at data_ids[ii1 * nbi1 + ii0].
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
|
|
|
|
// Push constants. Member order and types must stay in sync with the host-side struct.
layout (push_constant) uniform parameter
{
    uint M;              // output extent along M (rows of A; bounds the BM tiles)
    uint N;              // output extent along N (columns of D; bounds the BN tiles)
    uint K;              // reduction length, in values
    uint stride_a;       // A row stride, in values (divided by BK to get block indices)
    uint stride_b;       // B row stride, in values
    uint stride_d;       // D column stride, in D_TYPE elements

    uint batch_stride_a; // per-batch (or per-expert) offset of A, in values
    uint batch_stride_b; // per-batch offset of B, in values
    uint batch_stride_d; // per-batch offset of D, in D_TYPE elements

#ifdef MUL_MAT_ID
    uint nei0;           // inner extent of the ids tensor scanned for this expert
    uint nei1;           // outer extent of the ids tensor
    uint nbi1;           // ids row stride, in ints
    uint ne11;           // the ids inner index is taken modulo this to pick the B row
#else
    uint k_split;        // values of K handled by each K-split slice
    uint ne02;           // A extent along dim 2
    uint ne12;           // B extent along dim 2
    uint broadcast2;     // B-to-A index ratio along dim 2
    uint broadcast3;     // B-to-A index ratio along dim 3
#endif
} p;
|
|
|
|
// Specialization constants chosen by the host pipeline.
layout (constant_id = 0) const uint BLOCK_SIZE = 64;  // invocations per workgroup
layout (constant_id = 1) const uint BM = 64;          // workgroup tile size along M
layout (constant_id = 2) const uint BN = 64;          // workgroup tile size along N
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;          // subgroup tile size along M
layout (constant_id = 5) const uint WN = 32;          // subgroup tile size along N
layout (constant_id = 6) const uint WMITER = 2;       // subgroup sub-tiles along M
layout (constant_id = 7) const uint TM = 4;           // per-thread (or coopmat) tile rows
layout (constant_id = 8) const uint TN = 2;           // per-thread (or coopmat) tile columns
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;       // subgroup width

// Values of K consumed per main-loop iteration; the loop advances the A/B block
// indices by exactly one block per step, so this is one quant block.
#define BK 32

// Row pitch of the shared quant tiles, in int32 words (BK / 4 packed words per row
// plus padding). NOTE(review): padding presumably avoids bank conflicts / meets
// coopmat load alignment — confirm.
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
|
|
|
|
// Shared tile of A's quants: BM rows of BK / 4 packed int32 words, row pitch SHMEM_STRIDE.
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];

// Per-row A block scale (QUANT_AUXF == 1, filled via get_d) or scale pair (via get_dm)
// for the current K step. main() writes this in its load loop on every build
// (COOPMAT included) and the COOPMAT compute path reads it, so it must be declared
// unconditionally.
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif

// Shared tile of B's quants, plus each row's q8_1 `ds` pair for the current K step.
// buf_b_ds is likewise written by main() on every build, so it is unconditional too.
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];

// Values covered by one load of A (QUANT_R packed halves of 4) and of B.
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4

#ifdef MUL_MAT_ID
// (ii0, ii1) positions in the ids tensor routed to this workgroup's expert.
// main() appends to this without a bounds check, so the capacity must be at least the
// largest nei0 * nei1 the host dispatches; keep it in sync with the host-side assert on
// nei0 * nei1 (raised from 3072 after an 8 x 512 ids tensor exceeded it).
// Entries are 16-bit, so ii0 and ii1 must each be below 65536.
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef COOPMAT
// One TM x TN staging slice per subgroup, used to move coopmat data through shared memory.
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
|
|
|
|
#include "mul_mmq_funcs.comp"
|
|
|
|
// Quantized matrix multiply using packed int8 dot products.
// A is read through the quant-specific helpers (get_d / get_dm / repack); B is q8_1-packed.
// Each workgroup accumulates one BM x BN tile of D over K in steps of BK values.
//
// Workgroup mapping (from the index math below):
//   gl_WorkGroupID.x        -> ir (M tile) and ik (K-split slice), packed as ik * blocks_m + ir
//   gl_WorkGroupID.y        -> ic (N tile; with MUL_MAT_ID, a tile of rows routed to the expert)
//   gl_GlobalInvocationID.z -> batch index (expert index with MUL_MAT_ID)
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
#else
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Split the flat batch index into (i12, i13) and map it onto A's batch,
    // dividing by the broadcast ratios.
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;

#ifdef COOPMAT
    const uint warp_i = gl_SubgroupID;

    const uint tiw = gl_SubgroupInvocationID;

    const uint cms_per_row = WM / TM;
    const uint cms_per_col = WN / TN;

    const uint storestride = WARP / TM;
    const uint store_r = tiw % TM;
    const uint store_c = tiw / TM;
#else
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    const uint tiw = gl_LocalInvocationID.x % WARP;

    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);
#endif

    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // loadr_*: which packed word of a BK-value row this invocation loads;
    // loadc_*: which tile row it starts at (advancing by loadstride_* per pass).
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);

    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
    // Collect the (ii0, ii1) ids positions routed to this expert. Every invocation runs the
    // same scan and writes identical values. _ne1 is not bounded here: this relies on the
    // host keeping nei0 * nei1 within the capacity of row_ids.
    uint _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                row_ids[_ne1] = u16vec2(ii0, ii1);
                _ne1++;
            }
        }
    }

    barrier();

    // Workgroup has no work. _ne1 is computed identically by every invocation, so this
    // early return is workgroup-uniform and no later barrier() is reached by only some.
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Block indices (units of BK values) of the current K step; both advance by one block
    // per main-loop iteration.
    uint pos_a_ib = (
#ifdef MUL_MAT_ID
        expert_idx * p.batch_stride_a +
#else
        batch_idx_a * p.batch_stride_a +
#endif
        ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
    // B rows come from row_ids, so their offsets are added per row in the load loop.
    uint pos_b_ib = 0;
#else
    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

#ifdef COOPMAT
    coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
    coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
    coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;

    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];

    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
    }
#else
    int32_t cache_a_qs[WMITER * TM * BK / 4];

    int32_t cache_b_qs[TN * BK / 4];

    ACC_TYPE sums[WMITER * TM * WNITER * TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }
#endif

#if QUANT_AUXF == 1
    FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
    FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif

    FLOAT_TYPE_VEC2 cache_b_ds[TN];

    for (uint block = start_k; block < end_k; block += BK) {
        // Stage this K step of A (BM rows, one block each) into shared memory.
        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
            const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
            const uint iqs = loadr_a;
            const uint buf_ib = loadc_a + l;

            if (iqs == 0) {
#if QUANT_AUXF == 1
                buf_a_dm[buf_ib] = get_d(ib);
#else
                buf_a_dm[buf_ib] = get_dm(ib);
#endif
            }
#if QUANT_R == 1
            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
            const i32vec2 vals = repack(ib, iqs);
            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs    ] = vals.x;
            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
        }

        // Stage this K step of B (BN rows, one q8_1 block each) into shared memory.
        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
            const uint buf_ib = loadc_b + l;
#ifdef MUL_MAT_ID
            const uint row_i = ic * BN + buf_ib;
            if (row_i >= _ne1) {
                // Past the last row routed to this expert: row_ids[row_i] was never written,
                // so it must not be used to address data_b. Zero-fill instead; these columns
                // are skipped when results are stored.
                if (loadr_b == 0) {
                    buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(0.0f);
                }
                buf_b_qs[buf_ib * SHMEM_STRIDE + loadr_b] = 0;
                continue;
            }
            const u16vec2 row_idx = row_ids[row_i];
            // idx is in LOAD_VEC_B-sized (int32 word) units, while pos_b_ib counts BK blocks,
            // so convert it with BK / LOAD_VEC_B words per block before adding.
            const uint idx = pos_b_ib * (BK / LOAD_VEC_B) + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
            const uint ib = idx / 8;
            const uint iqs = idx & 0x7;
#else
            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
            const uint iqs = loadr_b;
#endif

            if (iqs == 0) {
                buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
            }
            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
        }

        barrier();

        pos_a_ib += 1;
        pos_b_ib += 1;

#ifdef COOPMAT
        // NOTE(review): this coopmat compute path looks incomplete. cache_b_dm, buf_b_d,
        // cache_a_d and cache_b_d are not declared in this file (confirm whether an include
        // provides them), and `factors` is declared as an array but used as a single matrix.
        // Kept verbatim; confirm before enabling COOPMAT for this shader.
        [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
            const uint ib_a = warp_r * WM + cm_row * TM;
            // Load from shared into cache
            coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

            // TODO: only cache values that are actually needed
            [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
                cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
            }

            [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
                const uint ib_b = warp_c * WN + cm_col * TN;
                coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

                // TODO: only cache values that are actually needed
                [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
                    cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
                }

                cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
                cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
                }

                coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
                sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
            }
        }
#else
        // Load from shared into cache
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
                cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
                    cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
                }
            }
        }

        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
                cache_b_ds[cc] = buf_b_ds[ib];
                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
                    cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
                }
            }

            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                        const uint cache_a_idx = wsir * TM + cr;
                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                        // Integer dot product over one block, then scaled by the block factors.
                        int32_t q_sum = 0;
                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
                            q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
                                                     cache_b_qs[cc * (BK / 4) + idx_k]);
                        }

                        sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
                    }
                }
            }
        }
#endif

        // Shared tiles are overwritten by the next K step.
        barrier();
    }

    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#ifdef COOPMAT
#ifdef MUL_MAT_ID
    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

            // The staging slice holds only TM x TN values for this subgroup, so walk TN
            // columns (as the non-ID stores below do); walking BN read other subgroups'
            // slices and wrote rows outside this coopmat tile.
            [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                const uint row_i = dc + cm_col * TN + col + store_c;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i];

                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
            }
        }
    }
#else
    const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float

    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;

            if (is_aligned && is_in_bounds) {
                // Full coopMat is within bounds and stride_d is aligned with 16B
                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
            } else if (is_in_bounds) {
                // Full coopMat is within bounds, but stride_d is not aligned
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
                // Partial coopMat is within bounds
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                    }
                }
            }
        }
    }
#endif // MUL_MAT_ID
#else
    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                // Columns past _ne1 hold padding rows (zero-filled at load time); skip them.
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
#else
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
#endif // COOPMAT
}
|