whisper : full batched decoding support

This commit is contained in:
Georgi Gerganov 2023-11-14 16:57:28 +02:00
parent 8b943f9843
commit 91096daa1a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -688,6 +688,7 @@ struct whisper_decoder {
// grammar parse state of generated sequence of tokens
whisper_grammar grammar;
int i_batch; // the index of the token in the current batch
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode?
@ -2228,7 +2229,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_allocr * alloc = wstate.alloc_decode.alloc;
const int n_ctx = hparams.n_text_ctx;
const int n_ctx = kv_self.size;
const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer;
@ -2569,7 +2570,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// compute logits only for the last token
// comment this line to compute logits for all n_tokens
// might be useful in the future
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
//cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
@ -2602,22 +2603,26 @@ static bool whisper_decode_internal(
const auto & model = wctx.model;
const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab;
const int n_vocab = hparams.n_vocab;
const int n_tokens = batch.n_tokens;
auto & logits_out = wstate.logits;
struct ggml_tensor * logits;
auto & kv_self = wstate.kv_self;
// find KV slot for the batch
{
auto & kv_self = wstate.kv_self;
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return 1;
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return false;
}
kv_self.n = whisper_kv_cache_cell_max(kv_self);
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
}
kv_self.n = whisper_kv_cache_cell_max(kv_self);
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
// decoder
{
auto & alloc = wstate.alloc_decode.alloc;
@ -2633,15 +2638,13 @@ static bool whisper_decode_internal(
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
}
// extract logits for all N tokens
//logits_out.resize(n_tokens*n_vocab);
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
//ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
// extract logits only for the last token
logits_out.resize(n_vocab);
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
logits_out.resize(n_tokens*n_vocab);
for (int i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
}
if (batch.n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@ -3074,7 +3077,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->backend = whisper_backend_init(ctx->params);
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
// TODO: determine how large the cache should be
const int factor = 2;
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
@ -4566,7 +4572,7 @@ static void whisper_process_logits(
auto & logprobs = decoder.logprobs;
{
logits.resize(n_logits);
memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
if (temperature > 0.0f) {
for (int i = 0; i < n_logits; i++) {
@ -5317,6 +5323,8 @@ int whisper_full_with_state(
{
const int64_t t_start_sample_us = ggml_time_us();
state->decoders[0].i_batch = prompt.size() - 1;
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
for (int j = 1; j < n_decoders_cur; ++j) {
@ -5384,7 +5392,6 @@ int whisper_full_with_state(
});
uint32_t cur_c = 0;
std::vector<int> decoder_idx(n_decoders_cur, -1);
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
@ -5408,8 +5415,6 @@ int whisper_full_with_state(
decoder.sequence = cur.sequence;
decoder.grammar = cur.grammar;
decoder_idx[j] = cur.decoder_idx;
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
@ -5535,32 +5540,52 @@ int whisper_full_with_state(
state->t_sample_us += ggml_time_us() - t_start_sample_us;
// obtain logits for the next token
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
{
auto & batch = state->batch;
if (decoder.failed || decoder.completed) {
continue;
}
batch.n_tokens = 0;
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
// TODO: use batch
const int n_past = prompt.size() + i;
whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
decoder.i_batch = batch.n_tokens;
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
batch.pos [batch.n_tokens] = n_past;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = j;
batch.logits [batch.n_tokens] = 1;
batch.n_tokens++;
}
assert(batch.n_tokens > 0);
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8;
}
{
const int64_t t_start_sample_us = ggml_time_us();
const int64_t t_start_sample_us = ggml_time_us();
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}