whisper : move kv_self to whisper_state

This commit is contained in:
Georgi Gerganov 2023-11-14 11:04:15 +02:00
parent 3cbaaed060
commit 8b943f9843
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -445,7 +445,7 @@ static void whisper_batch_free(struct whisper_batch batch) {
if (batch.logits) free(batch.logits);
}
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) {
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
batch.n_tokens = n_tokens;
for (int i = 0; i < n_tokens; ++i) {
if (tokens) {
@ -453,7 +453,7 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token
}
batch.pos [i] = n_past + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i][0] = 0;
batch.seq_id [i][0] = seq_id;
batch.logits [i] = 0;
}
batch.logits[n_tokens - 1] = 1;
@ -654,11 +654,11 @@ struct whisper_partial_utf8 {
};
struct whisper_grammar {
/*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
std::vector<std::vector<const whisper_grammar_element *>> stacks;
/*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
std::vector<std::vector<const whisper_grammar_element *>> stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
whisper_partial_utf8 partial_utf8;
whisper_partial_utf8 partial_utf8;
};
struct whisper_grammar_candidate {
@ -682,9 +682,6 @@ struct whisper_sequence {
// TAGS: WHISPER_DECODER_INIT
struct whisper_decoder {
// each decoder keeps its own KV-cache
whisper_kv_cache kv_self;
// the currently generated sequence of tokens
whisper_sequence sequence;
@ -701,8 +698,6 @@ struct whisper_decoder {
std::vector<float> probs;
std::vector<float> logits;
std::vector<float> logprobs;
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
};
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
@ -717,12 +712,6 @@ struct whisper_pair {
whisper_pair() : first(A()), second(B()) {}
};
// beam-search helpers
struct kv_buf {
std::vector<uint8_t> k;
std::vector<uint8_t> v;
};
// ggml_allocr wrapper for whisper usage
struct whisper_allocr {
ggml_allocr * alloc = nullptr;
@ -787,18 +776,19 @@ struct whisper_state {
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
// unified self-attention KV cache for all decoders
whisper_kv_cache kv_self;
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
whisper_mel mel;
whisper_batch batch;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// buffer for swapping KV caches between decoders during beam-search
std::vector<kv_buf> kv_swap_bufs;
ggml_backend_t backend = nullptr;
// ggml-alloc:
@ -1046,7 +1036,7 @@ static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache)
}
}
return 0;
return 1;
}
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
@ -1057,6 +1047,36 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
cache.head = 0;
}
static void whisper_kv_cache_seq_rm(
struct whisper_kv_cache & cache,
whisper_seq_id seq_id,
whisper_pos p0,
whisper_pos p1) {
uint32_t new_head = cache.size;
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
cache.cells[i].seq_id.clear();
} else if (cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].seq_id.erase(seq_id);
} else {
continue;
}
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
}
}
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
}
static void whisper_kv_cache_seq_cp(
struct whisper_kv_cache & cache,
whisper_seq_id seq_id_src,
@ -2197,13 +2217,12 @@ static bool whisper_encode_internal(
static struct ggml_cgraph * whisper_build_graph_decoder(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_batch & batch) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
// TODO: move to wstate
auto & kv_self = decoder.kv_self;
auto & kv_self = wstate.kv_self;
WHISPER_ASSERT(!!kv_self.ctx);
@ -2374,7 +2393,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
n_kv, n_state/n_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@ -2574,7 +2593,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_batch & batch,
const int n_threads,
whisper_abort_callback abort_callback,
@ -2590,13 +2608,15 @@ static bool whisper_decode_internal(
struct ggml_tensor * logits;
auto & kv_self = decoder.kv_self;
auto & kv_self = wstate.kv_self;
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return 1;
}
kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
kv_self.n = whisper_kv_cache_cell_max(kv_self);
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
// decoder
{
@ -2604,7 +2624,7 @@ static bool whisper_decode_internal(
ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, batch);
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
ggml_allocr_alloc_graph(alloc, gf);
@ -3054,14 +3074,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->backend = whisper_backend_init(ctx->params);
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
}
{
const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
@ -3147,9 +3167,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past);
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch);
return whisper_build_graph_decoder(*ctx, *state, state->batch);
});
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@ -3386,12 +3406,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
void whisper_free_state(struct whisper_state * state)
{
if (state) {
kv_cache_free(state->kv_self);
kv_cache_free(state->kv_cross);
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
kv_cache_free(state->decoders[i].kv_self);
}
#ifdef WHISPER_USE_COREML
if (state->ctx_coreml != nullptr) {
whisper_coreml_free(state->ctx_coreml);
@ -3534,11 +3551,9 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
}
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int selected_decoder_id = 0;
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past);
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) {
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
@ -3547,17 +3562,14 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
}
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
// TODO: add selected_decoder_id to state
const int selected_decoder_id = 0;
if (ctx->state == nullptr) {
WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
return false;
}
whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past);
whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) {
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
@ -4178,8 +4190,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
// that cannot satisfy this position in grammar
if (tok.partial_utf8.n_remain != 0 &&
!whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
rejects.push_back(tok);
}
} else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@ -5006,125 +5017,6 @@ static void whisper_sequence_score(
}
}
static bool whisper_kv_swap_fast(
std::vector<int> & view,
whisper_decoder src[],
std::vector<kv_buf> & kv_swap_bufs,
const int & n_decoders) {
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
// (buffer->decoder or decoder->decoder)
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
// (decoder<->decoder)
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
std::vector<whisper_pair<int, int>> p_swap_vec;
p_swap_vec.reserve(n_decoders);
// see https://github.com/ggerganov/whisper.cpp/wiki
for (int i = 0; i < n_decoders; i++) {
// zero-copy (no modification)
if (i == view[i] || view[i] < 0) {
continue;
}
bool is_one_copy = true;
// since we modify data sequentially, we only consider decoder indices after current index
for (int j = i + 1; j < n_decoders; j++) {
if (i == view[j]) {
// detect symmetric diagram
if (j == view[i]) {
p_swap_set.insert(i);
p_swap_set.insert(j);
p_swap_vec.emplace_back(i, j);
} else {
two_copy.insert(i);
is_one_copy = false;
}
break;
}
}
if (is_one_copy) {
one_copy.insert(i);
}
}
kv_swap_bufs.resize(n_decoders);
for (int i = 0; i < n_decoders; i++) {
kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
}
for (auto & i : two_copy) {
// make a copy of KV caches
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
//memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
//memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
}
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
for (auto & i : two_copy) {
// skip the decoder indices that require pointer swapping
if (p_swap_set.find(i) != p_swap_set.end()) {
continue;
}
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
//memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
//memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else {
// modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
//memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
//memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
}
}
// then modify one-copy decoder KV caches
for (auto & i : one_copy) {
// skip the decoder indices that require pointer swapping
if (p_swap_set.find(i) != p_swap_set.end()) {
continue;
}
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
//memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
//memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else {
// modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
//memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
//memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
}
}
// swap the pointers
for (auto & i : p_swap_vec) {
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
std::swap(src[i.first].kv_self, src[i.second].kv_self);
}
return true;
}
int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
@ -5218,21 +5110,11 @@ int whisper_full_with_state(
for (int j = 1; j < n_decoders; j++) {
auto & decoder = state->decoders[j];
if (decoder.kv_self.ctx == nullptr) {
decoder.kv_self = state->decoders[0].kv_self;
if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
return -4;
}
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
}
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
}
// the accumulated text context so far
@ -5309,6 +5191,7 @@ int whisper_full_with_state(
bool has_ts;
whisper_sequence sequence;
whisper_grammar grammar;
};
std::vector<beam_candidate> beam_candidates;
@ -5378,8 +5261,6 @@ int whisper_full_with_state(
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
decoder.kv_self.n = 0;
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs_all = 0.0;
@ -5395,15 +5276,14 @@ int whisper_full_with_state(
decoder.has_ts = false;
if (params.grammar_rules != nullptr) {
decoder.grammar = whisper_grammar_init(
params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
} else {
decoder.grammar = {};
}
}
// init prompt and kv cache for the current iteration
// run whisper_decoder() only for decoder 0 and copy the results for the other decoders
// TODO: do not recompute the prompt if it is the same as previous time
{
prompt.clear();
@ -5425,11 +5305,11 @@ int whisper_full_with_state(
}
WHISPER_PRINT_DEBUG("\n\n");
whisper_kv_cache_clear(state->decoders[0].kv_self);
whisper_kv_cache_clear(state->kv_self);
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0);
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -7;
}
@ -5439,18 +5319,10 @@ int whisper_full_with_state(
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
state->decoders[0].kv_self.n += prompt.size();
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
// TODO: fix CUDA
//memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
//memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
decoder.kv_self.n += prompt.size();
whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
@ -5492,7 +5364,7 @@ int whisper_full_with_state(
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
beam_candidates.back().sequence.tokens.push_back(token);
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
@ -5531,17 +5403,30 @@ int whisper_full_with_state(
++cur_c;
}
decoder.sequence = cur.sequence;
decoder.seek_delta = cur.seek_delta;
decoder.has_ts = cur.has_ts;
decoder.sequence = cur.sequence;
decoder.grammar = cur.grammar;
decoder_idx[j] = cur.decoder_idx;
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
}
// update KV caches
whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
}
}
// update the decoder state
@ -5657,14 +5542,14 @@ int whisper_full_with_state(
continue;
}
decoder.tokens_tmp.resize(1);
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
// TODO: use batch
const int n_past = prompt.size() + i;
whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n);
whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8;
}
@ -5674,8 +5559,6 @@ int whisper_full_with_state(
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
++decoder.kv_self.n;
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}