mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-01-27)
whisper : full batched decoding support
commit 91096daa1a
parent 8b943f9843

whisper.cpp (99 changed lines)
@@ -688,6 +688,7 @@ struct whisper_decoder {
     // grammar parse state of generated sequence of tokens
     whisper_grammar grammar;
 
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
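The new i_batch field is what lets a decoder find its own results after several decoders share a single decode call. A minimal sketch of the intended access pattern, assuming the row-major [n_tokens x n_vocab] logits layout used later in this diff (the helper name is illustrative, not part of the commit):

    #include <vector>

    // sketch: locate a decoder's logits inside the shared, batched output buffer;
    // row i_batch holds the n_vocab logits produced for that decoder's token
    static const float * decoder_logits_row(const std::vector<float> & logits_out,
                                            int i_batch, int n_vocab) {
        return logits_out.data() + (size_t) i_batch * n_vocab;
    }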
@@ -2228,7 +2229,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
-    const int n_ctx   = hparams.n_text_ctx;
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
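Sizing the graph with kv_self.size instead of the model's fixed n_text_ctx lets the decoder attend over the whole shared KV cache, which this commit makes larger than a single sequence. A rough sketch of the numbers involved, assuming the factor-2 cache allocated later in this diff (n_text_ctx is 448 for Whisper models):

    // illustrative only: how the attention context grows with the shared cache
    const int n_text_ctx = 448;                  // fixed per-model text context
    const int factor     = 2;                    // from whisper_init_state below
    const int n_ctx_old  = n_text_ctx;           // 448 positions, one sequence
    const int n_ctx_new  = factor * n_text_ctx;  // 896 cells, shared by all decoders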
@@ -2569,7 +2570,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     // compute logits only for the last token
     // comment this line to compute logits for all n_tokens
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
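With the single-row view commented out, cur keeps its full width, so the output projection now yields logits for every batch token instead of only the last one. A sketch of the shapes (ggml stores the fastest-varying dimension first):

    // cur:    [n_state, n_tokens]  -- no longer narrowed to [n_state, 1]
    // d_te:   [n_state, n_vocab]   -- token embedding matrix reused as the output head
    // logits: ggml_mul_mat(ctx0, model.d_te, cur)  ->  [n_vocab, n_tokens]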
@@ -2602,22 +2603,26 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
-    auto & kv_self = wstate.kv_self;
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
 
-    if (!whisper_kv_cache_find_slot(kv_self, batch)) {
-        return 1;
-    }
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
 
-    kv_self.n = whisper_kv_cache_cell_max(kv_self);
-    //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
-    //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    }
 
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
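whisper_kv_cache_find_slot has to reserve one cache cell per batch token before the graph runs; when the cache is full or too fragmented, decoding now fails cleanly (return false, also fixing the earlier return 1 in a bool function) instead of overwriting. A hypothetical, simplified first-fit search over such a cache; the real implementation also records each cell's position and sequence ids:

    #include <vector>

    struct kv_cell { bool used = false; };

    // sketch: return the first index with n_tokens consecutive free cells, or -1
    static int find_slot(const std::vector<kv_cell> & cells, int n_tokens) {
        for (int i = 0; i + n_tokens <= (int) cells.size(); ++i) {
            bool ok = true;
            for (int j = 0; j < n_tokens; ++j) {
                if (cells[i + j].used) { ok = false; break; }
            }
            if (ok) {
                return i; // caller marks cells [i, i + n_tokens) as used
            }
        }
        return -1; // no room -> whisper_decode_internal returns false
    }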
@@ -2633,15 +2638,13 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // extract logits for all N tokens
-    //logits_out.resize(n_tokens*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
     if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
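The extraction loop keeps row i of logits_out aligned with token i of the batch and skips rows whose batch.logits flag is 0 (those rows are simply left untouched), so each decoder can index the buffer by its batch position. A worked example of the offsets, with illustrative numbers:

    // e.g. the English-only Whisper vocabulary, three active decoders, one token each
    const int n_vocab  = 51864;
    const int n_tokens = 3;

    // for token i (when batch.logits[i] != 0):
    //   source byte offset:      sizeof(float)*(n_vocab*i) into the logits tensor
    //   destination float offset: n_vocab*i into logits_out
    // so token 2's logits occupy floats [2*51864, 3*51864) in logits_out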
@@ -3074,7 +3077,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    // TODO: determine how large the cache should be
+    const int factor = 2;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
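The factor of 2 is a stopgap, as the TODO notes: the cache now has to hold one sequence per active decoder, plus the temporary beam-search copies made later in this diff, so the right size really depends on the decoder count and sequence lengths. A back-of-the-envelope sketch of the budget, with hypothetical numbers:

    const int n_cells    = 2 * 448;  // factor * n_text_ctx: 896 shared KV cells
    const int n_decoders = 5;        // hypothetical beam size
    const int n_grown    = 120;      // hypothetical tokens decoded so far per sequence

    const int used = n_decoders * n_grown; // 600 of 896 cells, leaving headroom
                                           // for the transient beam-search copies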
@@ -4566,7 +4572,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
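Previously state.logits held only the last token's row, so taking the tail of the buffer was enough; with batched extraction each decoder must instead read the row selected by its i_batch. A worked example, assuming three decoders were packed into one batch:

    // row offsets read by each decoder after a batched decode
    const int    n_logits = 51864;                 // == n_vocab
    const size_t off0     = 0 * (size_t) n_logits; // decoder packed first
    const size_t off1     = 1 * (size_t) n_logits; // decoder packed second
    const size_t off2     = 2 * (size_t) n_logits; // decoder packed third
    // the old tail-based offset (state.logits.size() - n_logits) would have
    // pointed every decoder at the same final row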
@@ -5317,6 +5323,8 @@ int whisper_full_with_state(
         {
             const int64_t t_start_sample_us = ggml_time_us();
 
+            state->decoders[0].i_batch = prompt.size() - 1;
+
             whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
             for (int j = 1; j < n_decoders_cur; ++j) {
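Decoder 0 evaluates the whole prompt in one batch here, and the legacy batch helper requests logits only for the final token (an assumption about whisper_batch_prep_legacy's behavior, consistent with how it is used elsewhere in this diff), so the row to sample from is the last one:

    // sketch: after decoding a prompt of P tokens, logits_out is laid out as
    // P rows of n_vocab floats with only row P - 1 populated, hence
    // i_batch = prompt.size() - 1 points whisper_process_logits at that row
    const size_t P   = prompt.size(); // prompt length in tokens
    const size_t row = P - 1;         // the only row with fresh logits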
@@ -5384,7 +5392,6 @@ int whisper_full_with_state(
                     });
 
                     uint32_t cur_c = 0;
-                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -5408,8 +5415,6 @@ int whisper_full_with_state(
                         decoder.sequence = cur.sequence;
                         decoder.grammar  = cur.grammar;
 
-                        decoder_idx[j] = cur.decoder_idx;
-
                         whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
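Because several winning beam candidates can descend from the same source sequence, each candidate's KV cells are first duplicated under a scratch sequence id (WHISPER_MAX_DECODERS + j) rather than directly into id j; a later pass, not visible in this excerpt, presumably moves the scratch copies into place so that no copy clobbers a source another candidate still needs. A schematic of such a two-phase shuffle, using the seq_cp/seq_rm calls visible in this commit (the loop shape is a sketch):

    // phase 1: park candidate j's cells under a scratch id
    for (int j = 0; j < n_decoders_cur; ++j) {
        whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
    }

    // phase 2 (assumed, not shown in this hunk): clear id j, move the scratch copy in
    for (int j = 0; j < n_decoders_cur; ++j) {
        whisper_kv_cache_seq_rm(state->kv_self,                           j, -1, -1);
        whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
        whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
    }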
@@ -5535,32 +5540,52 @@ int whisper_full_with_state(
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
-                    }
+                    batch.n_tokens = 0;
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
-
-                    // TODO: use batch
                     const int n_past = prompt.size() + i;
 
-                    whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
+
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
 
                     if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
                     {
                         const int64_t t_start_sample_us = ggml_time_us();
 
-                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                        for (int j = 0; j < n_decoders_cur; ++j) {
+                            auto & decoder = state->decoders[j];
 
-                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                            if (decoder.failed || decoder.completed) {
+                                continue;
+                            }
+
+                            whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                        }
+
+                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
                    }
                 }
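The net effect of this hunk: instead of one whisper_decode_internal call per active decoder per sampling step, the loop packs one token per live decoder into a single batch, each token under its own seq_id and with its logits flag set, runs one decode, and then fans the results back out via i_batch. A rough sketch of the savings, with hypothetical numbers:

    // with 5 beam decoders and a 100-token segment, the old path issued
    // ~500 single-token decodes; the batched path issues ~100 five-token decodes
    const int n_active = 5;
    const int n_steps  = 100;

    const int calls_before = n_steps * n_active; // 500 graph evaluations
    const int calls_after  = n_steps;            // 100 evaluations, 5 tokens each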