diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c0bdb9e1..65c8cf8b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -368,6 +368,8 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_dw_whcn_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; @@ -680,6 +682,24 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; +}; + struct vk_op_upscale_push_constants { uint32_t ne; uint32_t a_offset; uint32_t d_offset; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -2529,6 +2549,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } @@ -5988,6 +6011,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } + return nullptr; default: return nullptr; } @@ -6014,6 +6046,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: return true; default: return false; @@ -6310,6 +6343,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_CONV_2D_DW: { const uint32_t ne = ggml_nelements(dst); if (ne > 262144) { @@ -7096,6 +7130,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); @@ -8116,6 +8174,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: @@ -8179,6 +8238,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { // These operations all go through ggml_vk_op_f32, so short-circuit and @@ -8352,6 +8412,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_LEAKY_RELU: ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); @@ -8473,6 +8537,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: @@ -9442,6 +9507,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 00000000..cde0e4b9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,104 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index cf74625c..daf4e78f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -544,6 +544,9 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + for (auto &c : compiles) { c.wait(); }