diff --git a/whisper.cpp b/whisper.cpp index 68626c97..bfb988fb 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1797,14 +1797,15 @@ static struct ggml_cgraph * whisper_build_graph_cross( cur); Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); + Vcross, + layer.cross_attn_v_b); Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, + n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, ( n_ctx)*ggml_element_size(wstate.kv_cross.v), (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); @@ -1851,22 +1852,6 @@ static bool whisper_encode_internal( #else ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); #endif - - //auto cur = wstate.embd_enc; - ////auto cur = gf->leafs[0]; - //printf("cur name = '%s'\n", cur->name); - - //float * res = (float *) cur->data; - //for (int i = 0; i < 10; ++i) { - // printf("%f ", res[i]); - //} - //printf("\n"); - //double sum = 0.0; - //for (int i = 0; i < ggml_nelements(cur); ++i) { - // sum += res[i]; - //} - //printf("sum: %f\n", sum); - //printf("n: %d\n", ggml_nelements(cur)); } // cross @@ -1879,22 +1864,16 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); +#ifdef GGML_USE_METAL + if (wstate.ctx_metal && false) { + ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); + ggml_metal_graph_compute(wstate.ctx_metal, gf); + } else { + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + } +#else ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); - - //auto cur = gf->nodes[gf->n_nodes - 1]; - //printf("cur name = '%s'\n", cur->name); - - //ggml_fp16_t * res = (ggml_fp16_t *) cur->data; - //for (int i = 0; i < 10; ++i) { - // printf("%f ", ggml_fp32_to_fp16(res[i])); - //} - //printf("\n"); - //double sum = 0.0; - //for (int i = 0; i < ggml_nelements(cur); ++i) { - // sum += ggml_fp32_to_fp16(res[i]); - //} - //printf("sum: %f\n", sum); - //printf("n: %d\n", ggml_nelements(cur)); +#endif } // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); @@ -2287,7 +2266,6 @@ static bool whisper_decode_internal( if (wstate.ctx_metal) { ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); - ggml_metal_get_tensor (wstate.ctx_metal, logits); } else { ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); } @@ -2775,8 +2753,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // TAGS: WHISPER_DECODER_INIT state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); - state->decoders[0].probs.reserve(ctx->vocab.n_vocab); - state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].probs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits.reserve (ctx->vocab.n_vocab); state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); static const size_t tensor_alignment = 32;