metal : run "cross" step on the GPU

This commit is contained in:
Georgi Gerganov 2023-09-12 20:11:13 +03:00
parent 9fdd415367
commit cd476375b4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -1797,14 +1797,15 @@ static struct ggml_cgraph * whisper_build_graph_cross(
cur);
Vcross = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.cross_attn_v_b,
Vcross),
Vcross);
Vcross,
layer.cross_attn_v_b);
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
n_state*n_ctx,
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
@@ -1851,22 +1852,6 @@ static bool whisper_encode_internal(
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
//auto cur = wstate.embd_enc;
////auto cur = gf->leafs[0];
//printf("cur name = '%s'\n", cur->name);
//float * res = (float *) cur->data;
//for (int i = 0; i < 10; ++i) {
// printf("%f ", res[i]);
//}
//printf("\n");
//double sum = 0.0;
//for (int i = 0; i < ggml_nelements(cur); ++i) {
// sum += res[i];
//}
//printf("sum: %f\n", sum);
//printf("n: %d\n", ggml_nelements(cur));
}
// cross
@@ -1879,22 +1864,16 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
#ifdef GGML_USE_METAL
if (wstate.ctx_metal && false) {
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
//auto cur = gf->nodes[gf->n_nodes - 1];
//printf("cur name = '%s'\n", cur->name);
//ggml_fp16_t * res = (ggml_fp16_t *) cur->data;
//for (int i = 0; i < 10; ++i) {
// printf("%f ", ggml_fp32_to_fp16(res[i]));
//}
//printf("\n");
//double sum = 0.0;
//for (int i = 0; i < ggml_nelements(cur); ++i) {
// sum += ggml_fp32_to_fp16(res[i]);
//}
//printf("sum: %f\n", sum);
//printf("n: %d\n", ggml_nelements(cur));
#endif
}
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
@@ -2287,7 +2266,6 @@ static bool whisper_decode_internal(
if (wstate.ctx_metal) {
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
ggml_metal_get_tensor (wstate.ctx_metal, logits);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
@@ -2775,8 +2753,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// TAGS: WHISPER_DECODER_INIT
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
static const size_t tensor_alignment = 32;