mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-24 06:46:37 +00:00
whisper : fix external encoder (#1860)
This commit is contained in:
parent
b742f13e70
commit
e3c5e2cba8
41
whisper.cpp
41
whisper.cpp
@ -1659,22 +1659,9 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|||||||
ggml_set_name(cur, "embd_conv");
|
ggml_set_name(cur, "embd_conv");
|
||||||
wstate.embd_conv = cur;
|
wstate.embd_conv = cur;
|
||||||
} else {
|
} else {
|
||||||
#ifdef WHISPER_USE_COREML
|
ggml_build_forward_expand(gf, mel);
|
||||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
|
||||||
ggml_allocr_alloc(alloc, cur);
|
|
||||||
|
|
||||||
if (!ggml_allocr_is_measure(alloc)) {
|
|
||||||
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#ifdef WHISPER_USE_OPENVINO
|
|
||||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||||
ggml_allocr_alloc(alloc, cur);
|
|
||||||
|
|
||||||
if (!ggml_allocr_is_measure(alloc)) {
|
|
||||||
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ggml_set_name(cur, "embd_enc");
|
ggml_set_name(cur, "embd_enc");
|
||||||
wstate.embd_enc = cur;
|
wstate.embd_enc = cur;
|
||||||
@ -1708,14 +1695,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
||||||
|
|
||||||
//ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
|
||||||
|
|
||||||
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
|
|
||||||
//ggml_allocr_alloc(alloc, cur);
|
|
||||||
|
|
||||||
//if (!ggml_allocr_is_measure(alloc)) {
|
|
||||||
// ggml_backend_tensor_copy(wstate.embd_conv, cur);
|
|
||||||
//}
|
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
||||||
|
|
||||||
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
|
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
|
||||||
@ -1957,14 +1936,6 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||||||
|
|
||||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
//ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
|
||||||
|
|
||||||
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
|
||||||
//ggml_allocr_alloc(alloc, cur);
|
|
||||||
|
|
||||||
//if (!ggml_allocr_is_measure(alloc)) {
|
|
||||||
// ggml_backend_tensor_copy(wstate.embd_enc, cur);
|
|
||||||
//}
|
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||||
|
|
||||||
const float Kscale = pow(float(n_state) / n_head, -0.25);
|
const float Kscale = pow(float(n_state) / n_head, -0.25);
|
||||||
@ -2037,13 +2008,13 @@ static bool whisper_encode_internal(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
|
||||||
|
|
||||||
// set the input
|
// set the input
|
||||||
{
|
{
|
||||||
const auto & mel_inp = wstate.mel;
|
const auto & mel_inp = wstate.mel;
|
||||||
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
|
||||||
|
|
||||||
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
|
|
||||||
|
|
||||||
assert(mel->type == GGML_TYPE_F32);
|
assert(mel->type == GGML_TYPE_F32);
|
||||||
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
|
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
|
||||||
|
|
||||||
@ -2068,6 +2039,12 @@ static bool whisper_encode_internal(
|
|||||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
#if defined(WHISPER_USE_COREML)
|
||||||
|
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
|
||||||
|
#elif defined(WHISPER_USE_OPENVINO)
|
||||||
|
whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user