whisper : initial Metal version

This commit is contained in:
Georgi Gerganov 2023-09-11 15:37:24 +03:00
parent 4845b9ed09
commit d3b2dd4955
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 122 additions and 41 deletions

View File

@ -64,6 +64,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f32);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@ -208,6 +209,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f32);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@ -274,6 +276,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(get_rows_f32);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@ -1003,6 +1006,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;

View File

@ -195,6 +195,22 @@ kernel void kernel_diag_mask_inf(
}
}
kernel void kernel_get_rows_f32(
device const float * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
for (int j = 0; j < ne00; j++) {
dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
}
}
kernel void kernel_norm(
device const void * src0,
device float * dst,

View File

@ -625,9 +625,9 @@ struct whisper_state {
// - stores meta info about the intermediate tensors into the `meta_*` buffers
// - stores the actual tensor data into the `data_*` buffers
ggml_allocr * alloc_encode = NULL;
ggml_allocr * alloc_cross = NULL;
ggml_allocr * alloc_decode = NULL;
ggml_allocr * alloc_encode = nullptr;
ggml_allocr * alloc_cross = nullptr;
ggml_allocr * alloc_decode = nullptr;
// meta data
std::vector<uint8_t> meta_encode;
@ -640,7 +640,7 @@ struct whisper_state {
std::vector<uint8_t> data_decode;
// result of the encoder
struct ggml_tensor * embd_enc = NULL;
struct ggml_tensor * embd_enc = nullptr;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@ -1982,9 +1982,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
cur),
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
cur,
layer.attn_ln_0_w),
layer.attn_ln_0_b);
}
// self-attention
@ -1994,10 +1994,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
Qcur = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.attn_q_b,
Qcur),
Qcur);
Qcur,
layer.attn_q_b);
Qcur = ggml_scale(ctx0, Qcur, KQscale);
@ -2015,10 +2013,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
Vcur = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.attn_v_b,
Vcur),
Vcur);
Vcur,
layer.attn_v_b);
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
@ -2079,8 +2075,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
cur,
layer.attn_ln_1_b);
}
// add the input
@ -2093,9 +2089,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
cur),
ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
cur,
layer.cross_attn_ln_0_w),
layer.cross_attn_ln_0_b);
}
// cross-attention
@ -2105,10 +2101,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
Qcur = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.cross_attn_q_b,
Qcur),
Qcur);
Qcur,
layer.cross_attn_q_b);
Qcur = ggml_scale(ctx0, Qcur, KQscale);
@ -2177,8 +2171,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
cur);
cur,
layer.cross_attn_ln_1_b);
}
// add the input
@ -2195,9 +2189,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
cur),
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
cur,
layer.mlp_ln_w),
layer.mlp_ln_b);
}
// fully connected
@ -2206,8 +2200,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
cur,
layer.mlp_0_b);
// GELU activation
cur = ggml_gelu(ctx0, cur);
@ -2218,8 +2212,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
cur,
layer.mlp_1_b);
}
inpL = ggml_add(ctx0, cur, inpFF);
@ -2233,9 +2227,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.d_ln_w, cur),
cur),
ggml_repeat(ctx0, model.d_ln_b, cur));
cur,
model.d_ln_w),
model.d_ln_b);
}
// compute logits only for the last token
@ -2291,9 +2285,20 @@ static bool whisper_decode_internal(
ggml_allocr_alloc_graph(alloc, gf);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
logits = gf->nodes[gf->n_nodes - 1];
#ifdef GGML_USE_METAL
if (wstate.ctx_metal && n_tokens == 1) {
// TODO: fix for non-1 tokens
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
ggml_metal_get_tensor (wstate.ctx_metal, logits);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
}
// extract logits for all N tokens
@ -2760,6 +2765,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
if (!state->ctx_coreml) {
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
#ifndef WHISPER_COREML_ALLOW_FALLBACK
delete state;
return nullptr;
#endif
} else {
@ -2840,7 +2846,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
@ -2852,6 +2858,61 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
}
#ifdef GGML_USE_METAL
state->ctx_metal = ggml_metal_init(1);
if (!state->ctx_metal) {
log("%s: ggml_metal_init() failed\n", __func__);
delete state;
return nullptr;
}
log("%s: Metal context initialized\n", __func__);
// this allocates all Metal resources and memory buffers
void * data_ptr = NULL;
size_t data_size = 0;
// TODO: add mmap support
//if (params.use_mmap) {
// data_ptr = ctx->model.mapping->addr;
// data_size = ctx->model.mapping->size;
//} else {
// data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
// data_size = ggml_get_mem_size (ctx->model.ctx);
//}
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
data_size = ggml_get_mem_size (ctx->model.ctx);
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
#define WHISPER_METAL_CHECK_BUF(result) \
if (!(result)) { \
log("%s: failed to add buffer\n", __func__); \
delete state; \
return nullptr; \
}
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
// TODO: handle multiple decoders
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
#undef WHISPER_METAL_CHECK_BUF
#endif
state->rng = std::mt19937(0);
return state;
@ -4561,8 +4622,8 @@ int whisper_full_with_state(
decoder.kv_self.n += prompt.size();
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}