ggml : im2col opts

This commit is contained in:
Georgi Gerganov 2023-11-11 10:41:00 +02:00
parent 3bfc43e3e3
commit 66bb2e9401
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 16 additions and 18 deletions

View File

@ -4747,11 +4747,11 @@ static __global__ void im2col_f32_f16(
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
} else {
dst[offset_dst] = __float2half(0.0f);
}
}

View File

@ -1327,11 +1327,11 @@ kernel void kernel_im2col_f16(
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
} else {
dst[offset_dst] = 0.0f;
}
}

22
ggml.c
View File

@ -1777,7 +1777,6 @@ static void ggml_setup_op_has_task_pass(void) {
p[GGML_OP_DIAG_MASK_INF ] = true;
p[GGML_OP_DIAG_MASK_ZERO ] = true;
p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
p[GGML_OP_IM2COL ] = true;
p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
p[GGML_OP_FLASH_ATTN_BACK ] = true;
p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@ -5122,8 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}
// ggml_conv_1d
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -5263,14 +5260,14 @@ struct ggml_tensor * ggml_conv_2d(
int p1,
int d0,
int d1) {
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
result =
ggml_reshape_4d(ctx,
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, result, result->ne[0], result->ne[3] * result->ne[2] * result->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OCIC, KH, KW] => [OC, IC * KH * KW]
result->ne[1], result->ne[2], a->ne[3], result->ne[3]); // [N, OC, OH, OW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OCIC, KH, KW] => [OC, IC * KH * KW]
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
return result;
}
@ -11757,7 +11754,6 @@ static void ggml_compute_forward_im2col_f16(
GGML_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_TASK_INIT) {
memset(dst->data, 0, ggml_nbytes(dst));
return;
}
@ -11783,7 +11779,9 @@ static void ggml_compute_forward_im2col_f16(
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
} else {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
}
}