whisper : full batched decoding support

This commit is contained in:
Georgi Gerganov 2023-11-14 16:57:28 +02:00
parent 8b943f9843
commit 91096daa1a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -688,6 +688,7 @@ struct whisper_decoder {
// grammar parse state of generated sequence of tokens // grammar parse state of generated sequence of tokens
whisper_grammar grammar; whisper_grammar grammar;
int i_batch; // the index of the token in the current batch
int seek_delta; // the window shift found so far based on the decoded timestamp tokens int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode? bool failed; // has the current segment failed to decode?
@ -2228,7 +2229,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_allocr * alloc = wstate.alloc_decode.alloc; ggml_allocr * alloc = wstate.alloc_decode.alloc;
const int n_ctx = hparams.n_text_ctx; const int n_ctx = kv_self.size;
const int n_state = hparams.n_text_state; const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head; const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer; const int n_layer = hparams.n_text_layer;
@ -2569,7 +2570,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// compute logits only for the last token // compute logits only for the last token
// comment this line to compute logits for all n_tokens // comment this line to compute logits for all n_tokens
// might be useful in the future // might be useful in the future
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
@ -2602,22 +2603,26 @@ static bool whisper_decode_internal(
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
const int n_tokens = batch.n_tokens;
auto & logits_out = wstate.logits; auto & logits_out = wstate.logits;
struct ggml_tensor * logits; struct ggml_tensor * logits;
auto & kv_self = wstate.kv_self; // find KV slot for the batch
{
auto & kv_self = wstate.kv_self;
if (!whisper_kv_cache_find_slot(kv_self, batch)) { if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return 1; return false;
}
kv_self.n = whisper_kv_cache_cell_max(kv_self);
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
} }
kv_self.n = whisper_kv_cache_cell_max(kv_self);
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
// decoder // decoder
{ {
auto & alloc = wstate.alloc_decode.alloc; auto & alloc = wstate.alloc_decode.alloc;
@ -2633,15 +2638,13 @@ static bool whisper_decode_internal(
ggml_graph_compute_helper(wstate.backend, gf, n_threads); ggml_graph_compute_helper(wstate.backend, gf, n_threads);
} }
// extract logits for all N tokens logits_out.resize(n_tokens*n_vocab);
//logits_out.resize(n_tokens*n_vocab); for (int i = 0; i < n_tokens; i++) {
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab); if (batch.logits[i] == 0) {
//ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab); continue;
}
// extract logits only for the last token ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
logits_out.resize(n_vocab); }
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
if (batch.n_tokens > 1) { if (batch.n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@ -3074,7 +3077,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->backend = whisper_backend_init(ctx->params); state->backend = whisper_backend_init(ctx->params);
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { // TODO: determine how large the cache should be
const int factor = 2;
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state; delete state;
return nullptr; return nullptr;
@ -4566,7 +4572,7 @@ static void whisper_process_logits(
auto & logprobs = decoder.logprobs; auto & logprobs = decoder.logprobs;
{ {
logits.resize(n_logits); logits.resize(n_logits);
memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
if (temperature > 0.0f) { if (temperature > 0.0f) {
for (int i = 0; i < n_logits; i++) { for (int i = 0; i < n_logits; i++) {
@ -5317,6 +5323,8 @@ int whisper_full_with_state(
{ {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
state->decoders[0].i_batch = prompt.size() - 1;
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
for (int j = 1; j < n_decoders_cur; ++j) { for (int j = 1; j < n_decoders_cur; ++j) {
@ -5384,7 +5392,6 @@ int whisper_full_with_state(
}); });
uint32_t cur_c = 0; uint32_t cur_c = 0;
std::vector<int> decoder_idx(n_decoders_cur, -1);
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j]; auto & decoder = state->decoders[j];
@ -5408,8 +5415,6 @@ int whisper_full_with_state(
decoder.sequence = cur.sequence; decoder.sequence = cur.sequence;
decoder.grammar = cur.grammar; decoder.grammar = cur.grammar;
decoder_idx[j] = cur.decoder_idx;
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
@ -5535,32 +5540,52 @@ int whisper_full_with_state(
state->t_sample_us += ggml_time_us() - t_start_sample_us; state->t_sample_us += ggml_time_us() - t_start_sample_us;
// obtain logits for the next token // obtain logits for the next token
for (int j = 0; j < n_decoders_cur; ++j) { {
auto & decoder = state->decoders[j]; auto & batch = state->batch;
if (decoder.failed || decoder.completed) { batch.n_tokens = 0;
continue;
}
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
// TODO: use batch
const int n_past = prompt.size() + i; const int n_past = prompt.size() + i;
whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j); for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
decoder.i_batch = batch.n_tokens;
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
batch.pos [batch.n_tokens] = n_past;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = j;
batch.logits [batch.n_tokens] = 1;
batch.n_tokens++;
}
assert(batch.n_tokens > 0);
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8; return -8;
} }
{ const int64_t t_start_sample_us = ggml_time_us();
const int64_t t_start_sample_us = ggml_time_us();
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
whisper_process_logits(*ctx, *state, params, decoder, t_cur); whisper_process_logits(*ctx, *state, params, decoder, t_cur);
state->t_sample_us += ggml_time_us() - t_start_sample_us;
} }
state->t_sample_us += ggml_time_us() - t_start_sample_us;
} }
} }