vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader (llama/13191)

* vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader
This commit is contained in:
Jeff Bolz 2025-05-01 13:19:31 -05:00 committed by Georgi Gerganov
parent 87b88ed01c
commit df458380d6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 11 additions and 5 deletions

View File

@ -2419,7 +2419,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
} }
} }
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@ -4972,6 +4972,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint64_t nb01 = src0->nb[1]; const uint64_t nb01 = src0->nb[1];
const uint64_t nb02 = src0->nb[2]; const uint64_t nb02 = src0->nb[2];
const uint64_t nb12 = src1->nb[2];
// const uint64_t ne10 = src1->ne[0]; // const uint64_t ne10 = src1->ne[0];
const uint64_t ne11 = src1->ne[1]; const uint64_t ne11 = src1->ne[1];
const uint64_t ne12 = src1->ne[2]; const uint64_t ne12 = src1->ne[2];
@ -4997,6 +4999,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
const uint32_t channel_stride_y = nb12 / sizeof(float);
const uint64_t qx_sz = ggml_nbytes(src0); const uint64_t qx_sz = ggml_nbytes(src0);
const uint64_t qy_sz = ggml_nbytes(src1); const uint64_t qy_sz = ggml_nbytes(src1);
@ -5027,7 +5030,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
// compute // compute
const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });

View File

@ -21,7 +21,9 @@ layout (push_constant) uniform parameter
uint nrows_x; uint nrows_x;
uint row_stride_x; uint row_stride_x;
uint channel_stride_x; uint channel_stride_x;
uint channel_stride_y;
uint channel_x_divisor; uint channel_x_divisor;
uint ne12;
uint b_offset; uint b_offset;
uint d_offset; uint d_offset;
} p; } p;
@ -33,6 +35,7 @@ void main() {
const uint row_x = gl_GlobalInvocationID.y; const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z; const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / p.channel_x_divisor; const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x; const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x; const uint nrows_dst = p.nrows_x;
@ -56,7 +59,7 @@ void main() {
const uint row_y = col_x; const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y; const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@ -72,7 +75,7 @@ void main() {
const uint row_y = col_x; const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y; const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@ -89,7 +92,7 @@ void main() {
const uint row_y = col_x; const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y; const uint iy = channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);