mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-09 20:13:14 +00:00
vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader (llama/13191)
* vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader
This commit is contained in:
parent
87b88ed01c
commit
df458380d6
@ -2419,7 +2419,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
@ -4972,6 +4972,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||||||
const uint64_t nb01 = src0->nb[1];
|
const uint64_t nb01 = src0->nb[1];
|
||||||
const uint64_t nb02 = src0->nb[2];
|
const uint64_t nb02 = src0->nb[2];
|
||||||
|
|
||||||
|
const uint64_t nb12 = src1->nb[2];
|
||||||
|
|
||||||
// const uint64_t ne10 = src1->ne[0];
|
// const uint64_t ne10 = src1->ne[0];
|
||||||
const uint64_t ne11 = src1->ne[1];
|
const uint64_t ne11 = src1->ne[1];
|
||||||
const uint64_t ne12 = src1->ne[2];
|
const uint64_t ne12 = src1->ne[2];
|
||||||
@ -4997,6 +4999,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||||||
|
|
||||||
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
|
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
|
||||||
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
|
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
|
||||||
|
const uint32_t channel_stride_y = nb12 / sizeof(float);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_nbytes(src0);
|
const uint64_t qx_sz = ggml_nbytes(src0);
|
||||||
const uint64_t qy_sz = ggml_nbytes(src1);
|
const uint64_t qy_sz = ggml_nbytes(src1);
|
||||||
@ -5027,7 +5030,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||||||
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
|
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
||||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
||||||
|
@ -21,7 +21,9 @@ layout (push_constant) uniform parameter
|
|||||||
uint nrows_x;
|
uint nrows_x;
|
||||||
uint row_stride_x;
|
uint row_stride_x;
|
||||||
uint channel_stride_x;
|
uint channel_stride_x;
|
||||||
|
uint channel_stride_y;
|
||||||
uint channel_x_divisor;
|
uint channel_x_divisor;
|
||||||
|
uint ne12;
|
||||||
uint b_offset;
|
uint b_offset;
|
||||||
uint d_offset;
|
uint d_offset;
|
||||||
} p;
|
} p;
|
||||||
@ -33,6 +35,7 @@ void main() {
|
|||||||
const uint row_x = gl_GlobalInvocationID.y;
|
const uint row_x = gl_GlobalInvocationID.y;
|
||||||
const uint channel = gl_GlobalInvocationID.z;
|
const uint channel = gl_GlobalInvocationID.z;
|
||||||
const uint channel_x = channel / p.channel_x_divisor;
|
const uint channel_x = channel / p.channel_x_divisor;
|
||||||
|
const uint channel_y = channel % p.ne12;
|
||||||
|
|
||||||
const uint nrows_y = p.ncols_x;
|
const uint nrows_y = p.ncols_x;
|
||||||
const uint nrows_dst = p.nrows_x;
|
const uint nrows_dst = p.nrows_x;
|
||||||
@ -56,7 +59,7 @@ void main() {
|
|||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||||
@ -72,7 +75,7 @@ void main() {
|
|||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||||
@ -89,7 +92,7 @@ void main() {
|
|||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user