Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-01-20 03:36:39 +00:00)
whisper : do not use F16 tensors when in F32 mode (#369)
commit ad2a4ffa03 (parent b3c865083e)

Changed file: whisper.cpp (40 lines)
@@ -412,6 +412,8 @@ struct whisper_context {
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
+    ggml_type wtype; // weight type (FP32 or FP16)
+
     whisper_model model;
     whisper_vocab vocab;
 
@@ -629,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // for the big tensors, we have the option to store the data in 16-bit floats
     // in order to save memory and also to speed up the computation
-    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    const ggml_type wtype = wctx.wtype;
 
     size_t ctx_size = 0;
 
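For reference, a minimal standalone sketch of the selection logic above. The enum and the size helper are stand-ins for ggml's ggml_type and ggml_type_size (not the real API), and f16 plays the role of model.hparams.f16 read from the model file:

#include <cstddef>
#include <cstdio>

// stand-ins for GGML_TYPE_F16 / GGML_TYPE_F32 and ggml_type_size()
enum wtype_sketch { SK_F16, SK_F32 };

static size_t type_size(wtype_sketch t) {
    return t == SK_F16 ? 2 : 4; // bytes per element, matching the real types
}

int main() {
    const bool f16 = false; // hypothetical F32 model file (hparams.f16 == 0)

    // same ternary as the patch; this single value now types every
    // weight tensor and every derived scratch/KV tensor
    const wtype_sketch wtype = f16 ? SK_F16 : SK_F32;

    printf("weight type: %s, %zu bytes/element\n",
           wtype == SK_F16 ? "F16" : "F32", type_size(wtype));
    return 0;
}

Storing the choice on the context (wctx.wtype) is what lets whisper_encode and whisper_full_parallel consult the same decision later instead of hard-coding F16.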
@@ -650,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // encoder
     {
-        // TODO: F16 .. maybe not?
         ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
 
         ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
@@ -665,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // decoder
    {
-        // TODO: F16 .. maybe not?
         ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
 
         ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
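The two removed TODOs settle a question the accounting now answers implicitly: the positional embeddings e_pe and d_pe are always counted at F32 size, while the weight tensors are counted at the size of wtype. A rough standalone sketch of that accounting for the first two encoder tensors, assuming base-model hyperparameters (other model sizes differ):

#include <cstdio>

int main() {
    // assumed base-model hyperparameters
    const size_t n_audio_ctx = 1500, n_audio_state = 512, n_mels = 80;

    for (size_t wsize : {2, 4}) { // wtype element size: 2 = F16, 4 = F32
        size_t ctx_size = 0;
        ctx_size += n_audio_ctx*n_audio_state*4;  // e_pe: always F32
        ctx_size += 3*n_mels*n_audio_state*wsize; // e_conv_1_w: follows wtype
        printf("wtype elem size %zu: %.2f MB (first two tensors only)\n",
               wsize, ctx_size/1024.0/1024.0);
    }
    return 0;
}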
@@ -982,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_mem = n_text_layer*n_text_ctx;
         const int n_elements = n_text_state*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
     }
 
     // key/value memory for the cross-attention layer
@@ -993,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_mem = n_text_layer*n_audio_ctx;
         const int n_elements = n_text_state*n_mem;
 
-        model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-        model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+        model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
     }
 
     const size_t memory_size =
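These two hunks carry the visible trade-off of the fix: the decoder's key/value memory now follows wtype instead of being pinned to F16, so an F32-mode run allocates twice the KV bytes but never creates F16 tensors. Back-of-the-envelope numbers, again assuming base-model hyperparameters:

#include <cstdio>

int main() {
    // assumed base-model hyperparameters
    const long long n_text_layer = 6, n_text_ctx = 448, n_text_state = 512;
    const long long n_audio_ctx  = 1500;

    // memory_k + memory_v, then memory_cross_k + memory_cross_v
    const long long self  = 2*n_text_layer*n_text_ctx *n_text_state;
    const long long cross = 2*n_text_layer*n_audio_ctx*n_text_state;

    for (int tsize : {2, 4}) { // element size: 2 = F16, 4 = F32
        printf("%s KV memory: %.2f MB\n", tsize == 2 ? "F16" : "F32",
               (self + cross)*tsize/1024.0/1024.0);
    }
    return 0;
}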
@@ -1240,14 +1242,14 @@ static bool whisper_encode(
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
                         Qcur,
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                     0, 2, 1, 3);
 
         struct ggml_tensor * K =
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
                         Kcur,
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                     0, 2, 1, 3);
 
         struct ggml_tensor * V =
@@ -1257,7 +1259,7 @@ static bool whisper_encode(
                         Vcur,
                         n_state/n_head, n_head, n_ctx),
                     1, 2, 0, 3),
-                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
+                ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
                 );
 
         struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1273,7 +1275,7 @@ static bool whisper_encode(
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
                         Kcur,
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                     0, 2, 1, 3);
 
         // K * Q
@@ -1291,7 +1293,7 @@ static bool whisper_encode(
         //    ggml_permute(ctxL,
         //            ggml_cpy(ctxL,
         //                Vcur,
-        //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+        //                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
         //            1, 2, 0, 3);
 
         //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1303,7 +1305,7 @@ static bool whisper_encode(
                         Vcur,
                         n_state/n_head, n_head, n_ctx),
                     0, 2, 1, 3),
-                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
+                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
                 );
 
         struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1348,7 +1350,7 @@ static bool whisper_encode(
 
 #ifdef USE_FLASH_FF
         cur = ggml_flash_ff(ctxL,
-                ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+                ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                 layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
         // fully connected
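All of the whisper_encode hunks follow one template: each scratch tensor that ggml_cpy casts Q/K/V (or the flash-FF input) into is now created with wctx.wtype rather than a hard-coded GGML_TYPE_F16, so in F32 mode no F16 intermediate is ever produced. A toy sketch of that invariant; the types and the check are illustrative stand-ins, not ggml code:

#include <cassert>

enum wtype_sketch { SK_F16, SK_F32 };

struct tensor_sketch { wtype_sketch type; };

// hypothetical stand-in for an op that expects consistent operand types
static tensor_sketch binary_op(const tensor_sketch & a, const tensor_sketch & b) {
    assert(a.type == b.type && "no mixed F16/F32 intermediates");
    return tensor_sketch{a.type};
}

int main() {
    const wtype_sketch wtype = SK_F32; // F32 mode

    tensor_sketch Qcur{SK_F32};    // activations are F32
    tensor_sketch scratch{wtype};  // scratch now inherits wtype
    binary_op(Qcur, scratch);      // consistent: both F32
    return 0;
}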
@@ -3156,7 +3158,7 @@ int whisper_full_parallel(
 
         // separate key + value memory for each processor
         {
-            auto & ctx = model.ctx_mem;
+            auto & mctx = model.ctx_mem;
 
             const auto & hparams = model.hparams;
 
@@ -3169,8 +3171,8 @@ int whisper_full_parallel(
             const int n_mem = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+            model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
         }
 
         // key/value memory for the cross-attention layer
@@ -3180,8 +3182,8 @@ int whisper_full_parallel(
             const int n_mem = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
         }
     }
 }
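In whisper_full_parallel the memory context is renamed to mctx, which frees ctx to refer unambiguously to the whisper_context argument whose ctx->wtype now types each processor's private key/value memory. Each processor re-allocates all four tensors; with the base-model numbers from the earlier sketch:

#include <cstdio>

int main() {
    // assumed base-model hyperparameters, as in the earlier sketch
    const long long n_text_layer = 6, n_text_ctx = 448, n_text_state = 512;
    const long long n_audio_ctx  = 1500;

    const long long per_processor =
        2*n_text_layer*n_text_ctx *n_text_state +  // memory_k + memory_v
        2*n_text_layer*n_audio_ctx*n_text_state;   // memory_cross_k + memory_cross_v

    const int tsize = 4; // F32 mode; 2 in F16 mode
    printf("KV memory per processor: %.2f MB\n",
           per_processor*tsize/1024.0/1024.0);
    return 0;
}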