From 12c462d6560a7da92ce8ca346445f27c154937d5 Mon Sep 17 00:00:00 2001
From: JidongZhang-THU <1119708529@qq.com>
Date: Wed, 31 Jan 2024 21:10:15 +0800
Subject: [PATCH] llava : add MobileVLM support (llama/5132)

* New features:
    1. Sum_Rows:
        fix CUDA kernel overflow
        fix block shape error when nrows is too big
    2. Im2Col:
        support batch in CUDA
        support f32 -> f32 in both CPU and CUDA
    3. DepthWiseConv:
        supported via Im2Col + MulMat (see the sketch after this list)
    4. Pool_2d:
        support avg pooling in CUDA
    5. HardSigmoid:
        implemented in CUDA: min(1, max(0, (x + 3) / 6))
    6. HardSwish:
        implemented in CUDA: x * min(1, max(0, (x + 3) / 6))

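The DepthWiseConv path has no dedicated kernel: the channel dimension is folded
into the batch dimension, patches are unfolded with Im2Col, and a per-channel
MulMat produces the output. A minimal sketch of that decomposition, mirroring
the new ggml_conv_depthwise_2d in this patch (the wrapper name is illustrative
only):

    // depthwise 2D convolution expressed as im2col + mul_mat
    // a: depthwise kernel, b: input image (ggml ne order, fastest dim first)
    static struct ggml_tensor * depthwise_conv_2d_sketch(
            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b,
            int s0, int s1, int p0, int p1, int d0, int d1) {
        // fold channels into the batch dim so every channel sees only its own kernel
        struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
        struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]);

        // unfold patches; note the dst_type argument added by this patch
        struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b,
                s0, s1, p0, p1, d0, d1, /*is_2D=*/ true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]

        // one matrix product per channel:
        //   [OC, 1, KH, KW]          => [1, OC, 1, KH * KW]
        //   [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
        struct ggml_tensor * result = ggml_mul_mat(ctx,
                ggml_reshape_4d(ctx, new_a, new_a->ne[0] * new_a->ne[1], new_a->ne[2], new_a->ne[3], 1),
                ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]));

        return ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
    }
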
* fix tabs instead of spaces

* code clean

* CUDA POOL2D

* ADD POOL2D test case in test-backend-ops.cpp

* code clean

* fix pool2d_kernel

nits

* fix bug in pool2d kernel

* fix avg pooling, count_include_pad

nits

* test-backend-ops : add more pool_2d tests

* cuda : fix warnings and formatting

* ggml : check types in release builds too in pool_2d

* test-backend-ops : remove f16 pool_2d tests

* cuda : more style fixes

* Add assert in ggml_cuda_op_pool2d

* pool2d float padding fallback

* test-backend-ops : add dst_type to im2col

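The new pool_2d cases in test-backend-ops.cpp drive the op through the public
API. A minimal standalone sketch (not the actual test harness; sizes, padding
values and the helper name are illustrative) of building and running a
GGML_OP_POOL_2D graph on the CPU:

    #include "ggml.h"

    static void pool2d_sketch(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // input in [W, H, C, N] order (ggml ne[0..3])
        struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 10, 3, 2);
        ggml_set_f32(x, 1.0f);

        // avg pooling, 3x3 kernel, stride 1, padding 1
        // (p0/p1 are float in the API, hence the float-padding fallback in the CUDA op)
        struct ggml_tensor * y = ggml_pool_2d(ctx, x, GGML_OP_POOL_AVG, 3, 3, 1, 1, 1.0f, 1.0f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        ggml_free(ctx);
    }
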
---------

Co-authored-by: slaren <slarengh@gmail.com>
---
 ggml-cuda.cu | 209 ++++++++++++++++++++++++++++++++++++++++++++++-----
 ggml.c       | 118 +++++++++++++++++++++++++----
 ggml.h       |   3 +-
 3 files changed, 296 insertions(+), 34 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 949bc8a1..e5659574 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_PAD_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
+#define CUDA_POOL2D_BLOCK_SIZE 256
 
 #define CUDA_Q8_0_NE_ALIGN 2048
 
@@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
 static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.y;
+    const int row = blockIdx.x;
     const int col = threadIdx.x;
 
     float sum = 0.0f;
@@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-static  __global__ void im2col_f32_f16(
-        const float * x, half * dst,
-        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
+template <typename T>
+static  __global__ void im2col_kernel(
+        const float * x, T * dst, int batch_offset,
+        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
@@ -6160,21 +6182,73 @@ static  __global__ void im2col_f32_f16(
     const int ky = (i - kd) / OW;
     const int ix = i % OW;
 
+    const int oh = blockIdx.y;
+    const int batch = blockIdx.z / IC;
+    const int ic = blockIdx.z % IC;
+
     const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
+    const int64_t iih = oh * s1 + ky * d1 - p1;
 
     const int64_t offset_dst =
-        (blockIdx.y * OW + ix) * CHW +
-        (blockIdx.z * (KW * KH) + ky * KW + kx);
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
 
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = __float2half(0.0f);
+        dst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = blockIdx.z * offset_delta;
-        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
     }
 }
 
+template <typename Ti, typename To>
+static  __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op) {
+        int idx = threadIdx.x + blockIdx.x * blockDim.x;
+        if (idx >= parallel_elements) {
+            return;
+        }
+
+        const int I_HW = ih * iw;
+        const int O_HW = oh * ow;
+        const int nc = idx / O_HW;
+        const int cur_oh = idx % O_HW / ow;
+        const int cur_ow = idx % O_HW % ow;
+        const Ti* i_ptr = src + nc * I_HW;
+        To* o_ptr = dst + nc * O_HW;
+        const int start_h = cur_oh * sh - ph;
+        const int bh = max(0, start_h);
+        const int eh = min(ih, start_h + kh);
+        const int start_w = cur_ow * sw - pw;
+        const int bw = max(0, start_w);
+        const int ew = min(iw, start_w + kw);
+        const To scale = 1. / (kh * kw);
+        To res = 0;
+
+        switch (op) {
+            case GGML_OP_POOL_AVG: res = 0; break;
+            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        }
+
+        for (int i = bh; i < eh; i += 1) {
+            for (int j = bw; j < ew; j += 1) {
+    #if __CUDA_ARCH__ >= 350
+                Ti cur = __ldg(i_ptr + i * iw + j);
+    #else
+                Ti cur = i_ptr[i * iw + j];
+    #endif
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += cur * scale; break;
+                    case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+                }
+            }
+        }
+        o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                             const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 
 static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
@@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     }
 }
 
-static void im2col_f32_f16_cuda(const float* x, half* dst,
+template <typename T>
+static void im2col_cuda(const float* x, T* dst,
     int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int offset_delta,
+    int batch, int batch_offset, int offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, IC);
-    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }
 
 // buffer pool for cuda
@@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
+static void ggml_cuda_op_hardsigmoid(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_hardswish(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_leaky_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
@@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+static void ggml_cuda_op_pool2d(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
+
+    (void) src1;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_im2col(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
     const int64_t OW =         dst->ne[1];
 
     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
 
-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }
 
     (void) src0;
     (void) src0_dd;
@@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }
 
+static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
+}
+
+static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
+}
 static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
 }
@@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
+}
+
 static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
@@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    func = ggml_cuda_hardsigmoid;
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    func = ggml_cuda_hardswish;
+                    break;
                 default:
                     return false;
             }
@@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_IM2COL:
             func = ggml_cuda_im2col;
             break;
+        case GGML_OP_POOL_2D:
+            func = ggml_cuda_pool2d;
+            break;
         case GGML_OP_SUM_ROWS:
             func = ggml_cuda_sum_rows;
             break;
@@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                     return true;
@@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
diff --git a/ggml.c b/ggml.c
index 1286ea8e..6dc9a525 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5349,7 +5349,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int                   s0,
         int                   p0,
         int                   d0) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
 
     struct ggml_tensor * result =
         ggml_mul_mat(ctx,
@@ -5427,16 +5427,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
     int                  p1,
     int                  d0,
     int                  d1) {
+
     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
                                         ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
-                                        s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
-
-    struct ggml_tensor * result =
-        ggml_mul_mat(ctx,
-                ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1),                       // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
-                ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
 
+    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
     result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
 
     return result;
@@ -5457,7 +5456,8 @@ struct ggml_tensor * ggml_im2col(
     int                  p1,
     int                  d0,
     int                  d1,
-    bool                 is_2D) {
+    bool                 is_2D,
+    enum ggml_type       dst_type) {
 
     if(is_2D) {
         GGML_ASSERT(a->ne[2] == b->ne[2]);
@@ -5481,7 +5481,7 @@ struct ggml_tensor * ggml_im2col(
         is_2D ?      b->ne[3] : 1,
     };
 
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
@@ -5506,7 +5506,7 @@ struct ggml_tensor * ggml_conv_2d(
         int                  p1,
         int                  d0,
         int                  d1) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
 
     struct ggml_tensor * result =
         ggml_mul_mat(ctx,
@@ -5632,12 +5632,13 @@ struct ggml_tensor * ggml_pool_2d(
         is_node = true;
     }
 
+    struct ggml_tensor * result;
     const int64_t ne[3] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
         a->ne[2],
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
     ggml_set_op_params(result, params, sizeof(params));
@@ -5645,7 +5646,6 @@ struct ggml_tensor * ggml_pool_2d(
     result->op = GGML_OP_POOL_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-
     return result;
 }
 
@@ -12493,6 +12493,92 @@ static void ggml_compute_forward_conv_transpose_1d(
     }
 }
 
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
 // src0: kernel [OC, IC, KH, KW]
 // src1: image [N, IC, IH, IW]
 // dst:  result [N, OH, OW, IC*KH*KW]
@@ -12583,14 +12669,14 @@ static void ggml_compute_forward_im2col(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->type) {
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                GGML_ASSERT(false);
+                ggml_compute_forward_im2col_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -12781,8 +12867,8 @@ static void ggml_compute_forward_pool_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src,
         struct ggml_tensor * dst) {
-    assert(src->type == GGML_TYPE_F32);
-    assert(params->ith == 0);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->ith == 0);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
diff --git a/ggml.h b/ggml.h
index 3b264bfd..1360cd8e 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1500,7 +1500,8 @@ extern "C" {
             int                  p1,
             int                  d0,
             int                  d1,
-            bool                 is_2D);
+            bool                 is_2D,
+            enum ggml_type       dst_type);
 
     GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
             struct ggml_context * ctx,