mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-19 20:57:52 +00:00
ggml : im2col opts
This commit is contained in:
parent
3bfc43e3e3
commit
66bb2e9401
@ -4747,11 +4747,11 @@ static __global__ void im2col_f32_f16(
|
||||
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
|
||||
|
||||
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = __float2half(0.0f);
|
||||
} else {
|
||||
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
} else {
|
||||
dst[offset_dst] = __float2half(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1327,11 +1327,11 @@ kernel void kernel_im2col_f16(
|
||||
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
||||
|
||||
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
||||
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
} else {
|
||||
dst[offset_dst] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
|
22
ggml.c
22
ggml.c
@ -1777,7 +1777,6 @@ static void ggml_setup_op_has_task_pass(void) {
|
||||
p[GGML_OP_DIAG_MASK_INF ] = true;
|
||||
p[GGML_OP_DIAG_MASK_ZERO ] = true;
|
||||
p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
|
||||
p[GGML_OP_IM2COL ] = true;
|
||||
p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
|
||||
p[GGML_OP_FLASH_ATTN_BACK ] = true;
|
||||
p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
|
||||
@ -5122,8 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
|
||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
||||
}
|
||||
|
||||
// ggml_conv_1d
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -5263,14 +5260,14 @@ struct ggml_tensor * ggml_conv_2d(
|
||||
int p1,
|
||||
int d0,
|
||||
int d1) {
|
||||
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
|
||||
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
|
||||
|
||||
result =
|
||||
ggml_reshape_4d(ctx,
|
||||
ggml_mul_mat(ctx,
|
||||
ggml_reshape_2d(ctx, result, result->ne[0], result->ne[3] * result->ne[2] * result->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
|
||||
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]
|
||||
result->ne[1], result->ne[2], a->ne[3], result->ne[3]); // [N, OC, OH, OW]
|
||||
struct ggml_tensor * result =
|
||||
ggml_mul_mat(ctx,
|
||||
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
|
||||
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
|
||||
|
||||
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -11757,7 +11754,6 @@ static void ggml_compute_forward_im2col_f16(
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
memset(dst->data, 0, ggml_nbytes(dst));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -11783,7 +11779,9 @@ static void ggml_compute_forward_im2col_f16(
|
||||
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
||||
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
||||
|
||||
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user