mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-22 05:57:48 +00:00
ggml: add support for float16 input tensors in pooling operations (ggml/895)
* Add support for float16 tensors in 1d pooling operations * Add support for float16 input tensors in 2d pooling operations * code cleanup remove unnecessary casting during srow ptr initialization --------- Co-authored-by: vanaka11 <vanaka1189@gmail.com>
This commit is contained in:
parent
8da6fd4dff
commit
b2ead7d6f4
@ -14579,7 +14579,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|||||||
|
|
||||||
const struct ggml_tensor * src = dst->src[0];
|
const struct ggml_tensor * src = dst->src[0];
|
||||||
|
|
||||||
assert(src->type == GGML_TYPE_F32);
|
assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
@ -14592,10 +14592,8 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|||||||
const int64_t rs = dst->ne[0];
|
const int64_t rs = dst->ne[0];
|
||||||
|
|
||||||
while (cdata < data_end) {
|
while (cdata < data_end) {
|
||||||
const float * const srow = (const float *)cdata;
|
const void * srow = (const void *)cdata;
|
||||||
|
|
||||||
int j = 0;
|
int j = 0;
|
||||||
|
|
||||||
for (int64_t i = 0; i < rs; ++i) {
|
for (int64_t i = 0; i < rs; ++i) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_POOL_AVG: drow[i] = 0; break;
|
case GGML_OP_POOL_AVG: drow[i] = 0; break;
|
||||||
@ -14603,10 +14601,11 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|||||||
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
||||||
}
|
}
|
||||||
for (int ki = 0; ki < k; ++ki) {
|
for (int ki = 0; ki < k; ++ki) {
|
||||||
|
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
|
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
|
||||||
case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
|
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
|
||||||
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
||||||
}
|
}
|
||||||
++j;
|
++j;
|
||||||
}
|
}
|
||||||
@ -14647,7 +14646,7 @@ static void ggml_compute_forward_pool_2d(
|
|||||||
|
|
||||||
const struct ggml_tensor * src = dst->src[0];
|
const struct ggml_tensor * src = dst->src[0];
|
||||||
|
|
||||||
GGML_ASSERT(src->type == GGML_TYPE_F32);
|
assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
@ -14690,14 +14689,15 @@ static void ggml_compute_forward_pool_2d(
|
|||||||
|
|
||||||
for (int ky = 0; ky < k1; ++ky) {
|
for (int ky = 0; ky < k1; ++ky) {
|
||||||
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
|
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
|
||||||
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
|
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
|
||||||
for (int kx = 0; kx < k0; ++kx) {
|
for (int kx = 0; kx < k0; ++kx) {
|
||||||
int j = ix + kx;
|
int j = ix + kx;
|
||||||
if (j < 0 || j >= src->ne[0]) continue;
|
if (j < 0 || j >= src->ne[0]) continue;
|
||||||
|
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_POOL_AVG: *out += srow[j]; break;
|
case GGML_OP_POOL_AVG: *out += srow_j; break;
|
||||||
case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
|
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
|
||||||
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user