mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-28 23:24:23 +00:00
whisper : initial Metal version
This commit is contained in:
parent
4845b9ed09
commit
d3b2dd4955
@ -64,6 +64,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(gelu);
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
@ -208,6 +209,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(gelu);
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
@ -274,6 +276,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(gelu);
|
GGML_METAL_DEL_KERNEL(gelu);
|
||||||
GGML_METAL_DEL_KERNEL(soft_max);
|
GGML_METAL_DEL_KERNEL(soft_max);
|
||||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||||
@ -1003,6 +1006,7 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
|
@ -195,6 +195,22 @@ kernel void kernel_diag_mask_inf(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_get_rows_f32(
|
||||||
|
device const float * src0,
|
||||||
|
device const int * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const int i = tpig;
|
||||||
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
|
for (int j = 0; j < ne00; j++) {
|
||||||
|
dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_norm(
|
kernel void kernel_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
139
whisper.cpp
139
whisper.cpp
@ -625,9 +625,9 @@ struct whisper_state {
|
|||||||
// - stores meta info about the intermediate tensors into the `meta_*` buffers
|
// - stores meta info about the intermediate tensors into the `meta_*` buffers
|
||||||
// - stores the actual tensor data into the `data_*` buffers
|
// - stores the actual tensor data into the `data_*` buffers
|
||||||
|
|
||||||
ggml_allocr * alloc_encode = NULL;
|
ggml_allocr * alloc_encode = nullptr;
|
||||||
ggml_allocr * alloc_cross = NULL;
|
ggml_allocr * alloc_cross = nullptr;
|
||||||
ggml_allocr * alloc_decode = NULL;
|
ggml_allocr * alloc_decode = nullptr;
|
||||||
|
|
||||||
// meta data
|
// meta data
|
||||||
std::vector<uint8_t> meta_encode;
|
std::vector<uint8_t> meta_encode;
|
||||||
@ -640,7 +640,7 @@ struct whisper_state {
|
|||||||
std::vector<uint8_t> data_decode;
|
std::vector<uint8_t> data_decode;
|
||||||
|
|
||||||
// result of the encoder
|
// result of the encoder
|
||||||
struct ggml_tensor * embd_enc = NULL;
|
struct ggml_tensor * embd_enc = nullptr;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
@ -1982,9 +1982,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
// cur = ln_0_w*cur + ln_0_b
|
// cur = ln_0_w*cur + ln_0_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
cur,
|
||||||
cur),
|
layer.attn_ln_0_w),
|
||||||
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
layer.attn_ln_0_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
@ -1994,10 +1994,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
Qcur = ggml_add(ctx0,
|
Qcur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0,
|
Qcur,
|
||||||
layer.attn_q_b,
|
layer.attn_q_b);
|
||||||
Qcur),
|
|
||||||
Qcur);
|
|
||||||
|
|
||||||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||||
|
|
||||||
@ -2015,10 +2013,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
Vcur = ggml_add(ctx0,
|
Vcur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0,
|
Vcur,
|
||||||
layer.attn_v_b,
|
layer.attn_v_b);
|
||||||
Vcur),
|
|
||||||
Vcur);
|
|
||||||
|
|
||||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
||||||
|
|
||||||
@ -2079,8 +2075,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
cur,
|
||||||
cur);
|
layer.attn_ln_1_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the input
|
// add the input
|
||||||
@ -2093,9 +2089,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
// cur = ln_0_w*cur + ln_0_b
|
// cur = ln_0_w*cur + ln_0_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
|
cur,
|
||||||
cur),
|
layer.cross_attn_ln_0_w),
|
||||||
ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
|
layer.cross_attn_ln_0_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// cross-attention
|
// cross-attention
|
||||||
@ -2105,10 +2101,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
Qcur = ggml_add(ctx0,
|
Qcur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0,
|
Qcur,
|
||||||
layer.cross_attn_q_b,
|
layer.cross_attn_q_b);
|
||||||
Qcur),
|
|
||||||
Qcur);
|
|
||||||
|
|
||||||
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
||||||
|
|
||||||
@ -2177,8 +2171,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
cur,
|
||||||
cur);
|
layer.cross_attn_ln_1_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the input
|
// add the input
|
||||||
@ -2195,9 +2189,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
cur,
|
||||||
cur),
|
layer.mlp_ln_w),
|
||||||
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
layer.mlp_ln_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// fully connected
|
// fully connected
|
||||||
@ -2206,8 +2200,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
cur,
|
||||||
cur);
|
layer.mlp_0_b);
|
||||||
|
|
||||||
// GELU activation
|
// GELU activation
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
@ -2218,8 +2212,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
cur,
|
||||||
cur);
|
layer.mlp_1_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
inpL = ggml_add(ctx0, cur, inpFF);
|
inpL = ggml_add(ctx0, cur, inpFF);
|
||||||
@ -2233,9 +2227,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, model.d_ln_w, cur),
|
cur,
|
||||||
cur),
|
model.d_ln_w),
|
||||||
ggml_repeat(ctx0, model.d_ln_b, cur));
|
model.d_ln_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute logits only for the last token
|
// compute logits only for the last token
|
||||||
@ -2291,9 +2285,20 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
ggml_allocr_alloc_graph(alloc, gf);
|
ggml_allocr_alloc_graph(alloc, gf);
|
||||||
|
|
||||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
|
||||||
|
|
||||||
logits = gf->nodes[gf->n_nodes - 1];
|
logits = gf->nodes[gf->n_nodes - 1];
|
||||||
|
|
||||||
|
#ifdef GGML_USE_METAL
|
||||||
|
if (wstate.ctx_metal && n_tokens == 1) {
|
||||||
|
// TODO: fix for non-1 tokens
|
||||||
|
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||||
|
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||||
|
ggml_metal_get_tensor (wstate.ctx_metal, logits);
|
||||||
|
} else {
|
||||||
|
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract logits for all N tokens
|
// extract logits for all N tokens
|
||||||
@ -2760,6 +2765,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
if (!state->ctx_coreml) {
|
if (!state->ctx_coreml) {
|
||||||
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
|
||||||
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
#ifndef WHISPER_COREML_ALLOW_FALLBACK
|
||||||
|
delete state;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
@ -2840,7 +2846,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
const int n_tokens = hparams.n_text_ctx;
|
const int n_tokens = hparams.n_text_ctx;
|
||||||
const int n_past = 0;
|
const int n_past = 0;
|
||||||
|
|
||||||
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
|
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
|
||||||
|
|
||||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
||||||
|
|
||||||
@ -2852,6 +2858,61 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_USE_METAL
|
||||||
|
state->ctx_metal = ggml_metal_init(1);
|
||||||
|
if (!state->ctx_metal) {
|
||||||
|
log("%s: ggml_metal_init() failed\n", __func__);
|
||||||
|
delete state;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
log("%s: Metal context initialized\n", __func__);
|
||||||
|
|
||||||
|
// this allocates all Metal resources and memory buffers
|
||||||
|
|
||||||
|
void * data_ptr = NULL;
|
||||||
|
size_t data_size = 0;
|
||||||
|
|
||||||
|
// TODO: add mmap support
|
||||||
|
//if (params.use_mmap) {
|
||||||
|
// data_ptr = ctx->model.mapping->addr;
|
||||||
|
// data_size = ctx->model.mapping->size;
|
||||||
|
//} else {
|
||||||
|
// data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
|
||||||
|
// data_size = ggml_get_mem_size (ctx->model.ctx);
|
||||||
|
//}
|
||||||
|
|
||||||
|
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
|
||||||
|
data_size = ggml_get_mem_size (ctx->model.ctx);
|
||||||
|
|
||||||
|
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
|
||||||
|
|
||||||
|
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
||||||
|
|
||||||
|
#define WHISPER_METAL_CHECK_BUF(result) \
|
||||||
|
if (!(result)) { \
|
||||||
|
log("%s: failed to add buffer\n", __func__); \
|
||||||
|
delete state; \
|
||||||
|
return nullptr; \
|
||||||
|
}
|
||||||
|
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
||||||
|
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
|
||||||
|
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
|
||||||
|
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
||||||
|
|
||||||
|
// TODO: handle multiple decoders
|
||||||
|
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
||||||
|
#undef WHISPER_METAL_CHECK_BUF
|
||||||
|
#endif
|
||||||
|
|
||||||
state->rng = std::mt19937(0);
|
state->rng = std::mt19937(0);
|
||||||
|
|
||||||
return state;
|
return state;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user