Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-01-27 14:49:55 +00:00)
metal : run "cross" step on the GPU
This commit is contained in:
parent 9fdd415367
commit cd476375b4
56
whisper.cpp
56
whisper.cpp
@@ -1797,14 +1797,15 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
||||
cur);
|
||||
|
||||
Vcross = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.cross_attn_v_b,
|
||||
Vcross),
|
||||
Vcross);
|
||||
Vcross,
|
||||
layer.cross_attn_v_b);
|
||||
|
||||
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
|
||||
n_state*n_ctx,
|
||||
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
||||
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
||||
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
||||
@@ -1851,22 +1852,6 @@ static bool whisper_encode_internal(
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
|
||||
//auto cur = wstate.embd_enc;
|
||||
////auto cur = gf->leafs[0];
|
||||
//printf("cur name = '%s'\n", cur->name);
|
||||
|
||||
//float * res = (float *) cur->data;
|
||||
//for (int i = 0; i < 10; ++i) {
|
||||
// printf("%f ", res[i]);
|
||||
//}
|
||||
//printf("\n");
|
||||
//double sum = 0.0;
|
||||
//for (int i = 0; i < ggml_nelements(cur); ++i) {
|
||||
// sum += res[i];
|
||||
//}
|
||||
//printf("sum: %f\n", sum);
|
||||
//printf("n: %d\n", ggml_nelements(cur));
|
||||
}
|
||||
|
||||
// cross
|
||||
@@ -1879,22 +1864,16 @@ static bool whisper_encode_internal(
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, gf);
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (wstate.ctx_metal && false) {
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
|
||||
//auto cur = gf->nodes[gf->n_nodes - 1];
|
||||
//printf("cur name = '%s'\n", cur->name);
|
||||
|
||||
//ggml_fp16_t * res = (ggml_fp16_t *) cur->data;
|
||||
//for (int i = 0; i < 10; ++i) {
|
||||
// printf("%f ", ggml_fp32_to_fp16(res[i]));
|
||||
//}
|
||||
//printf("\n");
|
||||
//double sum = 0.0;
|
||||
//for (int i = 0; i < ggml_nelements(cur); ++i) {
|
||||
// sum += ggml_fp32_to_fp16(res[i]);
|
||||
//}
|
||||
//printf("sum: %f\n", sum);
|
||||
//printf("n: %d\n", ggml_nelements(cur));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
@@ -2287,7 +2266,6 @@ static bool whisper_decode_internal(
|
||||
if (wstate.ctx_metal) {
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
ggml_metal_get_tensor (wstate.ctx_metal, logits);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
@@ -2775,8 +2753,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
// TAGS: WHISPER_DECODER_INIT
|
||||
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
||||
|
||||
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
||||
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
||||
|
||||
static const size_t tensor_alignment = 32;
|
||||
|
Loading…
x
Reference in New Issue
Block a user