From d3b2dd4955e40172c2dec4a39a3f0c2874a905be Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 11 Sep 2023 15:37:24 +0300
Subject: [PATCH] whisper : initial Metal version

---
 ggml-metal.m     |   4 ++
 ggml-metal.metal |  16 ++++++
 whisper.cpp      | 143 +++++++++++++++++++++++++++++++++--------------
 3 files changed, 122 insertions(+), 41 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 7e2355ce..58d62911 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -64,6 +64,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_f32);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -208,6 +209,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_f32);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -274,6 +276,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f32);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -1003,6 +1006,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_GET_ROWS:
                 {
                     switch (src0->type) {
+                        case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f32];  break;
                         case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 5070561f..8cf59f4e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -195,6 +195,22 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
+kernel void kernel_get_rows_f32(
+        device const float * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int j = 0; j < ne00; j++) {
+        dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
+    }
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device       float * dst,
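Note: kernel_get_rows_f32 above is a row gather — each thread i reads the row index r = src1[i] and copies row r of src0 (ne00 floats, rows nb01 bytes apart) into row i of dst. A minimal CPU sketch of the same operation, assuming a contiguous f32 destination; the function name and signature here are illustrative, not part of ggml:

    #include <cstdint>
    #include <cstring>

    // CPU reference for the Metal row-gather kernel: dst row i <- src0 row src1[i].
    // nb01 is the byte stride between consecutive src0 rows, as in the kernel;
    // dst is assumed contiguous with ne00 floats per row.
    static void get_rows_f32_ref(const void * src0, const int32_t * src1, float * dst,
                                 int64_t n_rows, int64_t ne00, uint64_t nb01) {
        for (int64_t i = 0; i < n_rows; ++i) {
            const float * row = (const float *)((const char *) src0 + src1[i]*nb01);
            std::memcpy(dst + i*ne00, row, ne00*sizeof(float));
        }
    }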
diff --git a/whisper.cpp b/whisper.cpp
index 1fe17245..13937ea8 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -625,9 +625,9 @@ struct whisper_state {
     //  - stores meta info about the intermediate tensors into the `meta_*` buffers
     //  - stores the actual tensor data into the `data_*` buffers
 
-    ggml_allocr * alloc_encode = NULL;
-    ggml_allocr * alloc_cross  = NULL;
-    ggml_allocr * alloc_decode = NULL;
+    ggml_allocr * alloc_encode = nullptr;
+    ggml_allocr * alloc_cross  = nullptr;
+    ggml_allocr * alloc_decode = nullptr;
 
     // meta data
     std::vector<uint8_t> meta_encode;
@@ -640,7 +640,7 @@ struct whisper_state {
     std::vector<uint8_t> data_decode;
 
     // result of the encoder
-    struct ggml_tensor * embd_enc = NULL;
+    struct ggml_tensor * embd_enc = nullptr;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -1982,9 +1982,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+                        cur,
+                        layer.attn_ln_0_w),
+                    layer.attn_ln_0_b);
         }
 
         // self-attention
@@ -1994,10 +1994,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
+                    Qcur,
+                    layer.attn_q_b);
 
             Qcur = ggml_scale(ctx0, Qcur, KQscale);
 
@@ -2015,10 +2013,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
+                    Vcur,
+                    layer.attn_v_b);
 
             Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
 
@@ -2079,8 +2075,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
+                    cur,
+                    layer.attn_ln_1_b);
         }
 
         // add the input
@@ -2093,9 +2089,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
+                        cur,
+                        layer.cross_attn_ln_0_w),
+                    layer.cross_attn_ln_0_b);
         }
 
         // cross-attention
@@ -2105,10 +2101,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_q_b,
-                        Qcur),
-                    Qcur);
+                    Qcur,
+                    layer.cross_attn_q_b);
 
             Qcur = ggml_scale(ctx0, Qcur, KQscale);
 
@@ -2177,8 +2171,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
-                    cur);
+                    cur,
+                    layer.cross_attn_ln_1_b);
         }
 
         // add the input
@@ -2195,9 +2189,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // cur = mlp_ln_w*cur + mlp_ln_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+                        cur,
+                        layer.mlp_ln_w),
+                    layer.mlp_ln_b);
         }
 
         // fully connected
@@ -2206,8 +2200,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
+                    cur,
+                    layer.mlp_0_b);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
@@ -2218,8 +2212,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     cur);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
+                    cur,
+                    layer.mlp_1_b);
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
@@ -2233,9 +2227,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.d_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.d_ln_b, cur));
+                    cur,
+                    model.d_ln_w),
+                model.d_ln_b);
     }
 
     // compute logits only for the last token
@@ -2291,9 +2285,20 @@ static bool whisper_decode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
-
         logits = gf->nodes[gf->n_nodes - 1];
+
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal && n_tokens == 1) {
+            // TODO: fix for non-1 tokens
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+            ggml_metal_get_tensor   (wstate.ctx_metal, logits);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
     }
 
     // extract logits for all N tokens
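Note: the whisper_build_graph_decoder changes above all follow one pattern — the explicit ggml_repeat of a bias/norm tensor up to the activation's shape is dropped, and ggml_add/ggml_mul broadcast the smaller tensor directly. This assumes a ggml revision whose ggml_add/ggml_mul support implicit row broadcasting of the second operand, which is what this patch relies on. A standalone sketch of the equivalence (tensor sizes illustrative; graph build/compute omitted):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // 3 rows of 4
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);    // per-column bias

        // old form: materialize b at x's shape, then add
        struct ggml_tensor * y0 = ggml_add(ctx, ggml_repeat(ctx, b, x), x);

        // new form: broadcast b over the rows of x, no intermediate tensor
        struct ggml_tensor * y1 = ggml_add(ctx, x, b);

        // both define y[i][j] = x[i][j] + b[j]
        (void) y0; (void) y1;
        ggml_free(ctx);
        return 0;
    }

Besides saving one intermediate tensor per bias, the broadcasting form keeps the decoder graph composed of ops the Metal backend implements, which is what lets whisper_decode_internal dispatch the whole graph to ggml_metal_graph_compute above.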
@@ -2760,6 +2765,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     if (!state->ctx_coreml) {
         log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
+        delete state;
         return nullptr;
 #endif
     } else {
@@ -2840,7 +2846,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             const int n_tokens = hparams.n_text_ctx;
             const int n_past   = 0;
 
-            ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
+            ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
 
             const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
@@ -2852,6 +2858,61 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
+#ifdef GGML_USE_METAL
+    state->ctx_metal = ggml_metal_init(1);
+    if (!state->ctx_metal) {
+        log("%s: ggml_metal_init() failed\n", __func__);
+        delete state;
+        return nullptr;
+    }
+
+    log("%s: Metal context initialized\n", __func__);
+
+    // this allocates all Metal resources and memory buffers
+
+    void * data_ptr  = NULL;
+    size_t data_size = 0;
+
+    // TODO: add mmap support
+    //if (params.use_mmap) {
+    //    data_ptr  = ctx->model.mapping->addr;
+    //    data_size = ctx->model.mapping->size;
+    //} else {
+    //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    //    data_size = ggml_get_mem_size  (ctx->model.ctx);
+    //}
+
+    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    data_size = ggml_get_mem_size  (ctx->model.ctx);
+
+    const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+
+    log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+
+#define WHISPER_METAL_CHECK_BUF(result)              \
+    if (!(result)) {                                 \
+        log("%s: failed to add buffer\n", __func__); \
+        delete state;                                \
+        return nullptr;                              \
+    }
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->meta_cross.data(),  state->meta_cross.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->data_cross.data(),  state->data_cross.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+
+    // TODO: handle multiple decoders
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+
+#undef WHISPER_METAL_CHECK_BUF
+#endif
+
     state->rng = std::mt19937(0);
 
     return state;
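Note: the initialization above follows the same lifecycle as llama.cpp's Metal backend — create one context, register every host allocation a graph may touch (model weights, allocator meta/data buffers, KV caches) so the backend can resolve tensor pointers into GPU buffer offsets, then dispatch graphs against it. A condensed, illustrative sketch; the variables (model_data, kv_buf, gf, logits, ...) are placeholders, and the real code above checks every ggml_metal_add_buffer result:

    #ifdef GGML_USE_METAL
    struct ggml_metal_context * mctx = ggml_metal_init(1);

    // register host memory regions once, up front
    ggml_metal_add_buffer(mctx, "data",      model_data, model_size, max_tensor_size);
    ggml_metal_add_buffer(mctx, "kv_self_0", kv_buf,     kv_size,    0);

    // per decode step: run the graph on the GPU, then read the result back
    ggml_metal_set_n_cb     (mctx, n_threads);
    ggml_metal_graph_compute(mctx, gf);
    ggml_metal_get_tensor   (mctx, logits);

    ggml_metal_free(mctx);
    #endif

On Apple Silicon the registered regions are not copied; the buffers wrap the existing unified memory, which is why every region the graph references must be registered before the first dispatch.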
@@ -4561,8 +4622,8 @@ int whisper_full_with_state(
 
                     decoder.kv_self.n += prompt.size();
 
-                    memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
-                    memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
+                    memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+                    memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
                     memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                 }
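Note: the final hunk only realigns the existing memcpy calls; functionally, whisper_full_with_state keeps seeding every active decoder from decoder 0's output after the shared prompt evaluation. An illustrative reduction of that pattern — the struct and function names here are stand-ins for this sketch, not whisper.cpp API:

    #include <cstring>
    #include <vector>

    struct decoder_output {
        std::vector<float> probs, logits, logprobs;
    };

    // Seed decoder `dst` with decoder 0's distributions so all beam/sampling
    // decoders start the segment from the same state; sizes are assumed equal.
    static void clone_decoder_output(const decoder_output & src, decoder_output & dst) {
        std::memcpy(dst.probs.data(),    src.probs.data(),    dst.probs.size()*sizeof(dst.probs[0]));
        std::memcpy(dst.logits.data(),   src.logits.data(),   dst.logits.size()*sizeof(dst.logits[0]));
        std::memcpy(dst.logprobs.data(), src.logprobs.data(), dst.logprobs.size()*sizeof(dst.logprobs[0]));
    }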