vulkan: implement several ops relevant for ggml_opt (llama/11769)

* vulkan: support memset_tensor

* vulkan: support GGML_OP_SUM

* vulkan: implement GGML_OP_ARGMAX

* vulkan: implement GGML_OP_SUB

* vulkan: implement GGML_OP_COUNT_EQUAL

* vulkan: implement GGML_OP_OPT_STEP_ADAMW

* vulkan: fix check_results RWKV_WKV6 crash and memory leaks

* vulkan: implement GGML_OP_REPEAT_BACK

* tests: remove invalid test-backend-ops REPEAT_BACK tests

* vulkan: fix COUNT_EQUAL memset using a fillBuffer command

Authored by Rémy O on 2025-02-17 07:55:57 +01:00, committed by Georgi Gerganov
parent 8a22a8b17f
commit 37a21dd43d
7 changed files with 564 additions and 218 deletions


@ -222,6 +222,7 @@ struct vk_device_struct {
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@ -232,7 +233,7 @@ struct vk_device_struct {
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_repeat_f32;
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@ -255,10 +256,13 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@ -2147,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@ -2169,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -2203,8 +2210,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@ -2218,6 +2229,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) {
c.wait();
}
@ -3783,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
}
}
static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
}
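// Note: fillBuffer is recorded into the current command buffer as a transfer
// operation, so callers inside the compute graph (see the GGML_OP_COUNT_EQUAL
// dispatch path below) bracket it with ggml_vk_sync_buffers() before the
// dependent dispatch reads the cleared buffer.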
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
@ -5189,6 +5208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
}
return nullptr;
case GGML_OP_SUB:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
}
return nullptr;
case GGML_OP_MUL:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@ -5250,6 +5274,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_repeat_f32;
}
return nullptr;
case GGML_OP_REPEAT_BACK:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_repeat_back_f32;
}
return nullptr;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
@ -5358,11 +5387,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_argsort_f32;
}
return nullptr;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sum_rows_f32;
}
return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
}
return nullptr;
case GGML_OP_COUNT_EQUAL:
if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
return ctx->device->pipeline_count_equal_i32;
}
return nullptr;
case GGML_OP_IM2COL:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_im2col_f32;
@ -5386,6 +5426,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
@ -5403,6 +5448,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CPY:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -5413,6 +5459,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
return true;
default:
@ -5627,6 +5674,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_RMS_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
{
const uint32_t nr = ggml_nrows(src0);
if (nr > 262144) {
@ -5637,6 +5685,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { nr, 1, 1 };
}
} break;
case GGML_OP_SUM:
// We use GGML_OP_SUM_ROWS with 1 row.
elements = { 1, 1, 1 };
break;
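// (The push constant KX is set to ggml_nelements(src0) in ggml_vk_sum(), so this
// single sum_rows workgroup reduces the whole tensor as one long row.)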
case GGML_OP_GROUP_NORM:
{
const uint32_t num_groups = dst->op_params[0];
@ -5683,6 +5735,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { N * OC * OH * OW, 1, 1};
} break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_MUL:
case GGML_OP_SCALE:
@ -5692,6 +5745,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_CPY:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
@ -5752,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
// im2col uses only src1 and dst buffers
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
ggml_vk_sync_buffers(subctx);
// count_equal assumes that destination buffer is initialized with zeroes
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
} else if (use_src2) {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@ -5814,6 +5874,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
}, dryrun);
}
static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@ -5972,6 +6047,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
const ggml_tensor * gm = dst->src[2];
const ggml_tensor * gv = dst->src[3];
const ggml_tensor * p = dst->src[4];
GGML_ASSERT(x->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(gm->type == GGML_TYPE_F32);
GGML_ASSERT(gv->type == GGML_TYPE_F32);
GGML_ASSERT(p->type == GGML_TYPE_F32);
GGML_ASSERT(dst->buffer != nullptr);
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(gm));
GGML_ASSERT(ggml_is_contiguous(gv));
GGML_ASSERT(ggml_is_contiguous(p));
GGML_ASSERT(ggml_are_same_shape(x, g));
GGML_ASSERT(ggml_are_same_shape(x, gm));
GGML_ASSERT(ggml_are_same_shape(x, gv));
GGML_ASSERT(ggml_nelements(p) == 7);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
ggml_vk_sync_buffers(subctx);
vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
X_uma = d_X != nullptr;
G_uma = d_G != nullptr;
GM_uma = d_GM != nullptr;
GV_uma = d_GV != nullptr;
P_uma = d_P != nullptr;
}
if (!X_uma) {
d_X = x_buf_ctx->dev_buffer;
x_offset = vk_tensor_offset(x) + x->view_offs;
}
if (!G_uma) {
d_G = g_buf_ctx->dev_buffer;
g_offset = vk_tensor_offset(g) + g->view_offs;
}
if (!GM_uma) {
d_GM = gm_buf_ctx->dev_buffer;
gm_offset = vk_tensor_offset(gm) + gm->view_offs;
}
if (!GV_uma) {
d_GV = gv_buf_ctx->dev_buffer;
gv_offset = vk_tensor_offset(gv) + gv->view_offs;
}
if (!P_uma) {
d_P = p_buf_ctx->dev_buffer;
p_offset = vk_tensor_offset(p) + p->view_offs;
}
const uint64_t x_size = ggml_nbytes(x);
const uint64_t g_size = ggml_nbytes(g);
const uint64_t gm_size = ggml_nbytes(gm);
const uint64_t gv_size = ggml_nbytes(gv);
const uint64_t p_size = ggml_nbytes(p);
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_X, x_offset, x_size },
vk_subbuffer{ d_G, g_offset, g_size },
vk_subbuffer{ d_GM, gm_offset, gm_size },
vk_subbuffer{ d_GV, gv_offset, gv_size },
vk_subbuffer{ d_P, p_offset, p_size },
}, sizeof(vk_op_push_constants), &pc, elements);
}
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);
ggml_vk_op_f32_opt_step_adamw(
ctx, subctx, dst,
{ (uint32_t)n, 0, 0.0f, 0.0f },
dryrun
);
}
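// Note: only the element count is passed through the push constants; the optimizer
// hyperparameters are read on the GPU from the 7-float src[4] tensor, which the
// shader interprets as { alpha, beta1, beta2, eps, wd, beta1h, beta2h }.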
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
int * op_params = (int *)dst->op_params;
@ -6105,6 +6285,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
}, dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
@ -6227,10 +6421,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}, dryrun);
}
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const int32_t s0 = dst->op_params[0];
const int32_t s1 = dst->op_params[1];
@ -7095,9 +7301,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
}
break;
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -7120,13 +7328,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
break;
default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@ -7147,9 +7359,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
} else {
switch (node->op) {
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ACC:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -7171,7 +7385,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
@ -7192,6 +7409,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_REPEAT:
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_REPEAT_BACK:
ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_ACC:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@ -7204,6 +7425,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_ADD:
ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_SUB:
ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_MUL:
ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@ -7291,10 +7516,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_ARGSORT:
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_SUM:
ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_ARGMAX:
ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_COUNT_EQUAL:
ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@ -7329,6 +7566,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_RWKV_WKV6:
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
break;
default:
return false;
@ -7380,6 +7622,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_GET_ROWS:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -7405,13 +7648,18 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TRANSPOSE:
case GGML_OP_NONE:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_OPT_STEP_ADAMW:
buf = tensor->buffer;
break;
@ -7603,6 +7851,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm
}
}
static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
uint32_t val32 = (uint32_t)value * 0x01010101;
ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
}
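// Note: the multiply by 0x01010101 broadcasts the byte into all four bytes of a
// 32-bit word, since the underlying Vulkan fill works on words rather than bytes.
// A minimal sketch of the expansion (hypothetical standalone helper, not part of the diff):
static uint32_t vk_broadcast_byte(uint8_t value) {
    return (uint32_t) value * 0x01010101u; // e.g. 0xAB -> 0xABABABAB
}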
static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@ -7647,7 +7904,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
/* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
/* .get_base = */ ggml_backend_vk_buffer_get_base,
/* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
/* .memset_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
@ -8300,6 +8557,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
} break;
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
case GGML_OP_REPEAT_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ROPE:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
@ -8313,6 +8572,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -8326,12 +8586,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
default:
return false;
@ -8604,8 +8868,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2];
ggml_tensor * src3 = tensor->src[3];
struct ggml_init_params iparams = {
/*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@ -8615,238 +8877,113 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
struct ggml_context * ggml_ctx = ggml_init(iparams);
struct ggml_tensor * src0_clone = nullptr;
struct ggml_tensor * src1_clone = nullptr;
struct ggml_tensor * src2_clone = nullptr;
struct ggml_tensor * src3_clone = nullptr;
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
struct ggml_tensor * tensor_clone = nullptr;
size_t src0_size;
size_t src1_size;
size_t src2_size;
size_t src3_size;
for (int i = 0; i < 6; i++) {
ggml_tensor * srci = tensor->src[i];
if (srci == nullptr) {
continue;
}
ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
size_t srci_size = ggml_nbytes(srci);
void * src0_buffer = nullptr;
void * src1_buffer = nullptr;
void * src2_buffer = nullptr;
void * src3_buffer = nullptr;
src_clone[i] = srci_clone;
src_size[i] = ggml_nbytes(srci);
src_buffer[i] = malloc(srci_size);
if (src0 != nullptr) {
src0_clone = ggml_dup_tensor(ggml_ctx, src0);
src0_size = ggml_nbytes(src0);
src0_buffer = malloc(src0_size);
src0_clone->data = src0_buffer;
if (ggml_backend_buffer_is_host(src0->buffer)) {
memcpy(src0_clone->data, src0->data, src0_size);
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (ggml_backend_buffer_is_vk(src0->buffer)) {
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
srci_clone->data = src_buffer[i];
if (ggml_backend_buffer_is_host(srci->buffer)) {
memcpy(srci_clone->data, srci->data, srci_size);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (ggml_backend_buffer_is_vk(srci->buffer)) {
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
for (int i3 = 0; i3 < src0->ne[3]; i3++) {
for (int i2 = 0; i2 < src0->ne[2]; i2++) {
const int idx = i3*src0->ne[2] + i2;
ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
for (int i3 = 0; i3 < srci->ne[3]; i3++) {
for (int i2 = 0; i2 < srci->ne[2]; i2++) {
const int idx = i3*srci->ne[2] + i2;
ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
}
}
src0_clone->nb[0] = src0->nb[0];
src0_clone->nb[1] = src0->nb[1];
srci_clone->nb[0] = srci->nb[0];
srci_clone->nb[1] = srci->nb[1];
for (int i = 2; i < GGML_MAX_DIMS; i++) {
src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
}
} else {
if (offset + src0_size >= buffer_gpu->size) {
src0_size = buffer_gpu->size - offset;
if (offset + srci_size >= buffer_gpu->size) {
srci_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
GGML_ABORT("fatal error");
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(src0, "src0");
}
}
if (src1 != nullptr) {
src1_clone = ggml_dup_tensor(ggml_ctx, src1);
src1_size = ggml_nbytes(src1);
src1_buffer = malloc(src1_size);
src1_clone->data = src1_buffer;
if (ggml_backend_buffer_is_host(src1->buffer)) {
memcpy(src1_clone->data, src1->data, src1_size);
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (ggml_backend_buffer_is_vk(src1->buffer)) {
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
for (int i3 = 0; i3 < src1->ne[3]; i3++) {
for (int i2 = 0; i2 < src1->ne[2]; i2++) {
const int idx = i3*src1->ne[2] + i2;
ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
}
}
src1_clone->nb[0] = src1->nb[0];
src1_clone->nb[1] = src1->nb[1];
for (int i = 2; i < GGML_MAX_DIMS; i++) {
src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
}
} else {
if (offset + src1_size >= buffer_gpu->size) {
src1_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
GGML_ABORT("fatal error");
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(src1, "src1");
}
}
if (src2 != nullptr) {
src2_clone = ggml_dup_tensor(ggml_ctx, src2);
src2_size = ggml_nbytes(src2);
src2_buffer = malloc(src2_size);
src2_clone->data = src2_buffer;
if (ggml_backend_buffer_is_host(src2->buffer)) {
memcpy(src2_clone->data, src2->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (ggml_backend_buffer_is_vk(src2->buffer)) {
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
for (int i3 = 0; i3 < src2->ne[3]; i3++) {
for (int i2 = 0; i2 < src2->ne[2]; i2++) {
const int idx = i3*src2->ne[2] + i2;
ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
}
}
src2_clone->nb[0] = src2->nb[0];
src2_clone->nb[1] = src2->nb[1];
for (int i = 2; i < GGML_MAX_DIMS; i++) {
src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
}
} else {
if (offset + src2_size >= buffer_gpu->size) {
src2_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
GGML_ABORT("fatal error");
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(src2, "src2");
}
}
if (src3 != nullptr) {
src3_clone = ggml_dup_tensor(ggml_ctx, src3);
src3_size = ggml_nbytes(src3);
src3_buffer = malloc(src3_size);
src3_clone->data = src3_buffer;
if (ggml_backend_buffer_is_host(src3->buffer)) {
memcpy(src3_clone->data, src3->data, src3_size);
memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (ggml_backend_buffer_is_vk(src3->buffer)) {
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
for (int i3 = 0; i3 < src3->ne[3]; i3++) {
for (int i2 = 0; i2 < src3->ne[2]; i2++) {
const int idx = i3*src3->ne[2] + i2;
ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
}
}
src3_clone->nb[0] = src3->nb[0];
src3_clone->nb[1] = src3->nb[1];
for (int i = 2; i < GGML_MAX_DIMS; i++) {
src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
}
} else {
if (offset + src3_size >= buffer_gpu->size) {
src3_size = buffer_gpu->size - offset;
}
ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
GGML_ABORT("fatal error");
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(src3, "src3");
ggml_vk_print_tensor(srci, srci_name[i]);
}
}
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
const float *params = (const float *)tensor->op_params;
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
} else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL_MAT_ID) {
tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SUB) {
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL) {
tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONCAT) {
tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_UPSCALE) {
tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_SCALE) {
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
} else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SIN) {
tensor_clone = ggml_sin(ggml_ctx, src0_clone);
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COS) {
tensor_clone = ggml_cos(ggml_ctx, src0_clone);
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else if (tensor->op == GGML_OP_PAD) {
tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
} else if (tensor->op == GGML_OP_REPEAT) {
tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
} else if (tensor->op == GGML_OP_REPEAT_BACK) {
tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
} else if (tensor->op == GGML_OP_ADD) {
tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
} else if (tensor->op == GGML_OP_RMS_NORM) {
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else {
tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
}
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_ROPE) {
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
@ -8860,26 +8997,26 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const float beta_slow = ((float *) tensor->op_params)[10];
if (mode & GGML_ROPE_TYPE_MROPE) {
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
} else {
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
}
} else if (tensor->op == GGML_OP_UNARY) {
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:
tensor_clone = ggml_silu(ggml_ctx, src0_clone);
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_GELU:
tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_GELU_QUICK:
tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_RELU:
tensor_clone = ggml_relu(ggml_ctx, src0_clone);
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_TANH:
tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@ -8887,28 +9024,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
}
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
if (src1 == nullptr) {
tensor_clone = ggml_dup(ggml_ctx, src0_clone);
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
tensor_clone->type = tensor->type;
} else {
tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_CONT) {
tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_RESHAPE) {
tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_VIEW) {
tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
} else if (tensor->op == GGML_OP_PERMUTE) {
int32_t * params = (int32_t *)tensor->op_params;
tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
} else if (tensor->op == GGML_OP_TRANSPOSE) {
tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_GET_ROWS) {
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARGSORT) {
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_SUM) {
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_IM2COL) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@ -8918,11 +9061,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t d1 = tensor->op_params[5];
const bool is_2D = tensor->op_params[6] == 1;
tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
const int32_t dim = tensor->op_params[0];
const int32_t max_period = tensor->op_params[1];
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
} else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
const int32_t k0 = tensor->op_params[1];
@ -8932,13 +9075,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t p0 = tensor->op_params[5];
const int32_t p1 = tensor->op_params[6];
tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
tensor->src[4], tensor->src[5]);
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
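// Note: the CPU reference op now receives the host-side clones. The old code handed
// the backend tensors (whose data lives in Vulkan device buffers) straight to
// ggml_rwkv_wkv6, which is the likely source of the check_results crash this commit fixes.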
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@ -8960,11 +9107,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
if (src0 != nullptr) {
free(src0_buffer);
}
if (src1 != nullptr) {
free(src1_buffer);
for (int i = 0; i < 6; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
}
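// Note: freeing all six staging buffers here (rather than just the src0/src1 ones,
// as before) is the check_results memory-leak fix mentioned in the commit message.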
ggml_free(ggml_ctx);
@ -9028,6 +9174,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
} else if (tensor->type == GGML_TYPE_I32) {
correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
} else if (tensor->type == GGML_TYPE_I64) {
correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
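// Note: the I64 branch is needed so the checker can validate GGML_OP_COUNT_EQUAL,
// whose output is a single GGML_TYPE_I64 value.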
} else {
std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
}


@ -0,0 +1,51 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (col >= p.KX) {
return;
}
A_TYPE amax = data_a[row*p.KX + col];
tmp[col] = col;
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
tmp[col] = i;
}
}
tmpmax[col] = amax;
barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}
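Note on the shader above: the specialization constant sets BLOCK_SIZE (and local_size_x) to the device subgroup size, one workgroup handles one row, each invocation scans a strided slice of the row keeping its running maximum and that maximum's index, and the shared-memory tree reduction then picks the row-wide winner, written out as one int per row. A rough CPU-side sketch of the value it should reproduce (hypothetical helper, not part of the diff; tie-breaking between equal maxima is not pinned down here):

static int32_t argmax_row_ref(const float * row, uint32_t n) {
    uint32_t best = 0;
    for (uint32_t i = 1; i < n; ++i) {
        if (row[i] > row[best]) {
            best = i;
        }
    }
    return (int32_t) best;
}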


@ -0,0 +1,31 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "generic_head.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
const uint CHUNK_SIZE = 512;
void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;
uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}
atomicAdd(data_d[0], D_TYPE(count));
}
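The shader above is why the dispatch path zero-fills the destination first: each workgroup covers one 512-element chunk (CHUNK_SIZE lines up with the {512,1,1} workgroup denominators the pipeline is created with), every invocation counts the matches in its share of the chunk, and the partial counts are accumulated into data_d[0] with atomicAdd, so the result is only correct if the buffer starts at zero. A rough CPU-side sketch of the intended semantics (hypothetical helper, not part of the diff):

static int64_t count_equal_ref(const int32_t * a, const int32_t * b, uint32_t n) {
    int64_t count = 0;
    for (uint32_t i = 0; i < n; ++i) {
        count += (a[i] == b[i]) ? 1 : 0; // dst is a single GGML_TYPE_I64 scalar
    }
    return count;
}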


@ -0,0 +1,42 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
layout (binding = 2) buffer GM {A_TYPE gradm[];};
layout (binding = 3) buffer GV {A_TYPE gradv[];};
layout (binding = 4) readonly buffer P {float params[7];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float alpha = params[0];
const float beta1 = params[1];
const float beta2 = params[2];
const float eps = params[3];
const float wd = params[4];
const float beta1h = params[5];
const float beta2h = params[6];
const float gi = grad[i];
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
gradm[i] = gmi;
gradv[i] = gvi;
const float mh = gmi*beta1h;
const float vh = sqrt(gvi*beta2h) + eps;
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
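Written out, the update above is the per-element AdamW step. A sketch in the shader's variable names (treating beta1h and beta2h as host-precomputed bias-correction factors is an assumption about how params is filled upstream; the shader simply multiplies by params[5] and params[6]):

\begin{aligned}
m_i &\leftarrow \beta_1\, m_i + (1-\beta_1)\, g_i \\
v_i &\leftarrow \beta_2\, v_i + (1-\beta_2)\, g_i^2 \\
x_i &\leftarrow x_i\,(1 - \alpha\,\mathrm{wd}) \;-\; \alpha\,\frac{\beta_{1h}\, m_i}{\sqrt{\beta_{2h}\, v_i} + \epsilon}
\end{aligned}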


@ -0,0 +1,37 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
// Destination multi-index (inlined dst_idx)
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
// Accumulate from sources
A_TYPE acc = A_TYPE(0);
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
}
}
}
}
data_d[get_doffset() + d_idx] = D_TYPE(acc);
}
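GGML_OP_REPEAT_BACK is the reduction counterpart of GGML_OP_REPEAT: every destination element sums all source positions that a forward repeat would have broadcast it to. In terms of the shader's dimensions (ne0x = source extents, ne1x = destination extents), roughly:

\mathrm{dst}[i_0, i_1, i_2, i_3] \;=\; \sum_{k_0, k_1, k_2, k_3 \ge 0} \mathrm{src}[\, i_0 + k_0\, ne_{10},\; i_1 + k_1\, ne_{11},\; i_2 + k_2\, ne_{12},\; i_3 + k_3\, ne_{13} \,]

where each sum only runs while the index stays below the corresponding source extent ne_{00}..ne_{03}.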


@ -0,0 +1,29 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}
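For reference, a minimal host-side sketch that would exercise the new SUB path (assumes a ggml_context `ctx` whose graph is executed on the Vulkan backend; setup elided):

struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
struct ggml_tensor * c = ggml_sub(ctx, a, b); // same shapes, so sub_f32_norepeat is selected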


@ -443,6 +443,8 @@ void process_shaders() {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@ -452,6 +454,7 @@ void process_shaders() {
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@ -501,7 +504,9 @@ void process_shaders() {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
@ -513,6 +518,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) {
c.wait();
}