mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-26 22:29:58 +00:00
whisper : initial Metal version
This commit is contained in:
parent
4845b9ed09
commit
d3b2dd4955
@ -64,6 +64,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(gelu);
|
||||
GGML_METAL_DECL_KERNEL(soft_max);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||
@ -208,6 +209,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(gelu);
|
||||
GGML_METAL_ADD_KERNEL(soft_max);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||
@ -274,6 +276,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(gelu);
|
||||
GGML_METAL_DEL_KERNEL(soft_max);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||
@ -1003,6 +1006,7 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||
|
@ -195,6 +195,22 @@ kernel void kernel_diag_mask_inf(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f32(
|
||||
device const float * src0,
|
||||
device const int * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb1,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const int i = tpig;
|
||||
const int r = ((device int32_t *) src1)[i];
|
||||
|
||||
for (int j = 0; j < ne00; j++) {
|
||||
dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_norm(
|
||||
device const void * src0,
|
||||
device float * dst,
|
||||
|
143
whisper.cpp
143
whisper.cpp
@ -625,9 +625,9 @@ struct whisper_state {
|
||||
// - stores meta info about the intermediate tensors into the `meta_*` buffers
|
||||
// - stores the actual tensor data into the `data_*` buffers
|
||||
|
||||
ggml_allocr * alloc_encode = NULL;
|
||||
ggml_allocr * alloc_cross = NULL;
|
||||
ggml_allocr * alloc_decode = NULL;
|
||||
ggml_allocr * alloc_encode = nullptr;
|
||||
ggml_allocr * alloc_cross = nullptr;
|
||||
ggml_allocr * alloc_decode = nullptr;
|
||||
|
||||
// meta data
|
||||
std::vector<uint8_t> meta_encode;
|
||||
@ -640,7 +640,7 @@ struct whisper_state {
|
||||
std::vector<uint8_t> data_decode;
|
||||
|
||||
// result of the encoder
|
||||
struct ggml_tensor * embd_enc = NULL;
|
||||
struct ggml_tensor * embd_enc = nullptr;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
std::vector<float> logits;
|
||||
@ -1982,9 +1982,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
// cur = ln_0_w*cur + ln_0_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
||||
cur,
|
||||
layer.attn_ln_0_w),
|
||||
layer.attn_ln_0_b);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
@ -1994,10 +1994,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
Qcur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.attn_q_b,
|
||||
Qcur),
|
||||
Qcur);
|
||||
Qcur,
|
||||
layer.attn_q_b);
|
||||
|
||||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||
|
||||
@ -2015,10 +2013,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
Vcur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.attn_v_b,
|
||||
Vcur),
|
||||
Vcur);
|
||||
Vcur,
|
||||
layer.attn_v_b);
|
||||
|
||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
||||
|
||||
@ -2079,8 +2075,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
||||
cur);
|
||||
cur,
|
||||
layer.attn_ln_1_b);
|
||||
}
|
||||
|
||||
// add the input
|
||||
@ -2093,9 +2089,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
// cur = ln_0_w*cur + ln_0_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
|
||||
cur,
|
||||
layer.cross_attn_ln_0_w),
|
||||
layer.cross_attn_ln_0_b);
|
||||
}
|
||||
|
||||
// cross-attention
|
||||
@ -2105,10 +2101,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
Qcur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.cross_attn_q_b,
|
||||
Qcur),
|
||||
Qcur);
|
||||
Qcur,
|
||||
layer.cross_attn_q_b);
|
||||
|
||||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||
|
||||
@ -2177,8 +2171,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
||||
cur);
|
||||
cur,
|
||||
layer.cross_attn_ln_1_b);
|
||||
}
|
||||
|
||||
// add the input
|
||||
@ -2195,9 +2189,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
||||
cur,
|
||||
layer.mlp_ln_w),
|
||||
layer.mlp_ln_b);
|
||||
}
|
||||
|
||||
// fully connected
|
||||
@ -2206,8 +2200,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
||||
cur);
|
||||
cur,
|
||||
layer.mlp_0_b);
|
||||
|
||||
// GELU activation
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
@ -2218,8 +2212,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
||||
cur);
|
||||
cur,
|
||||
layer.mlp_1_b);
|
||||
}
|
||||
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
@ -2233,9 +2227,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.d_ln_w, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.d_ln_b, cur));
|
||||
cur,
|
||||
model.d_ln_w),
|
||||
model.d_ln_b);
|
||||
}
|
||||
|
||||
// compute logits only for the last token
|
||||
@ -2291,9 +2285,20 @@ static bool whisper_decode_internal(
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, gf);
|
||||
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
|
||||
logits = gf->nodes[gf->n_nodes - 1];
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (wstate.ctx_metal && n_tokens == 1) {
|
||||
// TODO: fix for non-1 tokens
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
ggml_metal_get_tensor (wstate.ctx_metal, logits);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
}
|
||||
|
||||
// extract logits for all N tokens
|
||||
@ -2760,6 +2765,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
if (!state->ctx_coreml) {
|
||||
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
||||
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
||||
delete state;
|
||||
return nullptr;
|
||||
#endif
|
||||
} else {
|
||||
@ -2840,7 +2846,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
const int n_tokens = hparams.n_text_ctx;
|
||||
const int n_past = 0;
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
||||
|
||||
@ -2852,6 +2858,61 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
state->ctx_metal = ggml_metal_init(1);
|
||||
if (!state->ctx_metal) {
|
||||
log("%s: ggml_metal_init() failed\n", __func__);
|
||||
delete state;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
log("%s: Metal context initialized\n", __func__);
|
||||
|
||||
// this allocates all Metal resources and memory buffers
|
||||
|
||||
void * data_ptr = NULL;
|
||||
size_t data_size = 0;
|
||||
|
||||
// TODO: add mmap support
|
||||
//if (params.use_mmap) {
|
||||
// data_ptr = ctx->model.mapping->addr;
|
||||
// data_size = ctx->model.mapping->size;
|
||||
//} else {
|
||||
// data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
|
||||
// data_size = ggml_get_mem_size (ctx->model.ctx);
|
||||
//}
|
||||
|
||||
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
|
||||
data_size = ggml_get_mem_size (ctx->model.ctx);
|
||||
|
||||
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
|
||||
|
||||
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
||||
|
||||
#define WHISPER_METAL_CHECK_BUF(result) \
|
||||
if (!(result)) { \
|
||||
log("%s: failed to add buffer\n", __func__); \
|
||||
delete state; \
|
||||
return nullptr; \
|
||||
}
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
||||
|
||||
// TODO: handle multiple decoders
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
||||
#undef WHISPER_METAL_CHECK_BUF
|
||||
#endif
|
||||
|
||||
state->rng = std::mt19937(0);
|
||||
|
||||
return state;
|
||||
@ -4561,8 +4622,8 @@ int whisper_full_with_state(
|
||||
|
||||
decoder.kv_self.n += prompt.size();
|
||||
|
||||
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
||||
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
||||
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
||||
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
||||
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user