Mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : move kv_self to whisper_state
commit 8b943f9843 (parent 3cbaaed060)
whisper.cpp: 297 changed lines
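
This commit replaces the per-decoder self-attention KV caches with a single cache owned by whisper_state. Previously every whisper_decoder held its own kv_self, and beam search shuffled whole K/V tensors between decoders (whisper_kv_swap_fast); now all decoders share one cache and each cache cell is tagged with the sequence id of the decoder that owns it, so beam bookkeeping becomes per-cell metadata updates (whisper_kv_cache_seq_cp / whisper_kv_cache_seq_rm) instead of tensor copies. The sketch below illustrates the cell/sequence-id idea the diff relies on: the field and helper names (pos, seq_id, has_seq_id, head, size, n) follow the diff, but the definitions themselves are illustrative assumptions, not the actual whisper.cpp declarations.

    // minimal sketch of a sequence-aware KV cache (illustrative, not whisper.cpp)
    #include <cstdint>
    #include <set>
    #include <vector>

    using whisper_pos    = int32_t;
    using whisper_seq_id = int32_t;

    struct whisper_kv_cell {
        whisper_pos pos = -1;            // -1 means the cell is free
        std::set<whisper_seq_id> seq_id; // which decoder sequences own this cell

        bool has_seq_id(const whisper_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }
    };

    struct whisper_kv_cache_sketch {
        uint32_t head = 0; // where to start searching for free slots
        uint32_t size = 0; // total number of cells
        uint32_t n    = 0; // how many cells the current decode graph attends over

        std::vector<whisper_kv_cell> cells; // metadata; the K/V tensors live alongside
    };
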
@@ -445,7 +445,7 @@ static void whisper_batch_free(struct whisper_batch batch) {
     if (batch.logits) free(batch.logits);
 }
 
-static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) {
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
     batch.n_tokens = n_tokens;
     for (int i = 0; i < n_tokens; ++i) {
         if (tokens) {
@@ -453,7 +453,7 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token
         }
         batch.pos     [i] = n_past + i;
         batch.n_seq_id[i] = 1;
-        batch.seq_id  [i][0] = 0;
+        batch.seq_id  [i][0] = seq_id;
         batch.logits  [i] = 0;
     }
     batch.logits[n_tokens - 1] = 1;
@@ -654,11 +654,11 @@ struct whisper_partial_utf8 {
 };
 
 struct whisper_grammar {
-    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
-    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
-    whisper_partial_utf8 partial_utf8;
+    whisper_partial_utf8 partial_utf8;
 };
 
 struct whisper_grammar_candidate {
@@ -682,9 +682,6 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
@@ -701,8 +698,6 @@ struct whisper_decoder {
     std::vector<float> probs;
     std::vector<float> logits;
    std::vector<float> logprobs;
-
-    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
 // replace std::pair by using customized pair struct (reason: std::pair is very slow)
@@ -717,12 +712,6 @@ struct whisper_pair {
     whisper_pair() : first(A()), second(B()) {}
 };
 
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
 // ggml_allocr wrapper for whisper usage
 struct whisper_allocr {
     ggml_allocr * alloc = nullptr;
@@ -787,18 +776,19 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
 
     whisper_mel mel;
 
     whisper_batch batch;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
-    // buffer for swapping KV caches between decoders during beam-search
-    std::vector<kv_buf> kv_swap_bufs;
-
     ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
@@ -1046,7 +1036,7 @@ static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache)
         }
     }
 
-    return 0;
+    return 1;
 }
 
 static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
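
Only the tail of whisper_kv_cache_cell_max is visible above: it returns one past the last occupied cell, and this commit raises its floor from 0 to 1 so that kv_self.n, which whisper_decode_internal now takes directly from it (see the hunk at 2590 below), never drops to zero. A toy version over the sketch cache from the note above, under that reading of the diff (the scan itself is outside the hunk, so this is an assumption):

    // toy cell_max: one past the last cell still owned by some sequence, min 1
    static int32_t sketch_kv_cache_cell_max(const whisper_kv_cache_sketch & cache) {
        for (uint32_t i = cache.size; i > 0; --i) {
            if (cache.cells[i - 1].pos >= 0 && !cache.cells[i - 1].seq_id.empty()) {
                return (int32_t) i;
            }
        }
        return 1; // the floor this commit raises from 0
    }
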
@@ -1057,6 +1047,36 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
     cache.head = 0;
 }
 
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
+}
+
 static void whisper_kv_cache_seq_cp(
         struct whisper_kv_cache & cache,
                  whisper_seq_id   seq_id_src,
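
whisper_kv_cache_seq_rm above removes one sequence's claim on a position range (a negative seq_id means every sequence, negative bounds mean the whole range) and frees any cell left with no owners. whisper_kv_cache_seq_cp appears only as context, but in this scheme it can be metadata-only: it tags the source sequence's cells with the destination id rather than copying K/V data. A toy version over the sketch cache, under that assumption:

    #include <limits>

    // toy seq_cp: cells owned by seq_id_src in [p0, p1) also gain seq_id_dst;
    // no tensor data moves, only per-cell ownership metadata changes
    static void sketch_kv_cache_seq_cp(
            whisper_kv_cache_sketch & cache,
            whisper_seq_id seq_id_src,
            whisper_seq_id seq_id_dst,
            whisper_pos p0,
            whisper_pos p1) {
        if (p0 < 0) p0 = 0;
        if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();

        for (auto & cell : cache.cells) {
            if (cell.has_seq_id(seq_id_src) && cell.pos >= p0 && cell.pos < p1) {
                cell.seq_id.insert(seq_id_dst);
            }
        }
    }
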
@@ -2197,13 +2217,12 @@ static bool whisper_encode_internal(
 static struct ggml_cgraph * whisper_build_graph_decoder(
         whisper_context & wctx,
         whisper_state   & wstate,
-        whisper_decoder & decoder,
         const whisper_batch & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    // TODO: move to wstate
-    auto & kv_self = decoder.kv_self;
+    auto & kv_self = wstate.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
@@ -2374,7 +2393,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         n_kv, n_state/n_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2574,7 +2593,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 static bool whisper_decode_internal(
         whisper_context & wctx,
         whisper_state   & wstate,
-        whisper_decoder & decoder,
         const whisper_batch & batch,
         const int n_threads,
         whisper_abort_callback abort_callback,
@@ -2590,13 +2608,15 @@ static bool whisper_decode_internal(
 
     struct ggml_tensor * logits;
 
-    auto & kv_self = decoder.kv_self;
+    auto & kv_self = wstate.kv_self;
 
     if (!whisper_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
 
-    kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+    kv_self.n = whisper_kv_cache_cell_max(kv_self);
+    //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+    //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
 
     // decoder
     {
@@ -2604,7 +2624,7 @@ static bool whisper_decode_internal(
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, batch);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
@@ -3054,14 +3074,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
-        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
@@ -3147,9 +3167,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         const int n_tokens = hparams.n_text_ctx;
         const int n_past   = 0;
 
-        whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past);
+        whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
 
-        return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch);
+        return whisper_build_graph_decoder(*ctx, *state, state->batch);
     });
 
     WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -3386,12 +3406,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
+        kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
 
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            kv_cache_free(state->decoders[i].kv_self);
-        }
-
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);
@@ -3534,11 +3551,9 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int selected_decoder_id = 0;
-
-    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past);
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3547,17 +3562,14 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to state
-    const int selected_decoder_id = 0;
-
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past);
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -4178,8 +4190,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
             // that cannot satisfy this position in grammar
-            if (tok.partial_utf8.n_remain != 0 &&
-                !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
                 rejects.push_back(tok);
             }
         } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@@ -5006,125 +5017,6 @@ static void whisper_sequence_score(
     }
 }
 
-static bool whisper_kv_swap_fast(
-        std::vector<int> & view,
-        whisper_decoder   src[],
-        std::vector<kv_buf> & kv_swap_bufs,
-        const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
-        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
         struct whisper_state * state,
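
The deleted whisper_kv_swap_fast solved an aliasing problem: after beam selection, decoder j must continue from the cache chosen for it (view[j]), and applying that permutation with naive in-place copies can overwrite a source that a later copy still needs; hence the two_copy/one_copy/pointer-swap classification and the kv_swap_bufs staging buffers. A minimal illustration of the hazard, with plain ints standing in for whole caches:

    int a = 10, b = 20; // caches of decoder 0 and decoder 1
    // apply the permutation view = {1, 0} naively, in place:
    a = b;              // decoder 0 takes decoder 1's cache: a == 20
    b = a;              // decoder 1 receives 20; the original 10 is lost
    // staging a copy first (or std::swap) is what whisper_kv_swap_fast did;
    // the unified cache below avoids the problem without moving any data
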
@@ -5218,21 +5110,11 @@ int whisper_full_with_state(
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
-                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
-
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
-
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
-
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
-        }
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
+
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
     }
 
     // the accumulated text context so far
@@ -5309,6 +5191,7 @@ int whisper_full_with_state(
         bool has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar  grammar;
     };
 
     std::vector<beam_candidate> beam_candidates;
@@ -5378,8 +5261,6 @@ int whisper_full_with_state(
         for (int j = 0; j < n_decoders_cur; ++j) {
             auto & decoder = state->decoders[j];
 
-            decoder.kv_self.n = 0;
-
             decoder.sequence.tokens.clear();
             decoder.sequence.result_len       = 0;
             decoder.sequence.sum_logprobs_all = 0.0;
@@ -5395,15 +5276,14 @@ int whisper_full_with_state(
             decoder.has_ts = false;
 
             if (params.grammar_rules != nullptr) {
-                decoder.grammar = whisper_grammar_init(
-                    params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
             } else {
                 decoder.grammar = {};
             }
         }
 
        // init prompt and kv cache for the current iteration
        // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
        // TODO: do not recompute the prompt if it is the same as previous time
        {
            prompt.clear();
@@ -5425,11 +5305,11 @@ int whisper_full_with_state(
             }
             WHISPER_PRINT_DEBUG("\n\n");
 
-            whisper_kv_cache_clear(state->decoders[0].kv_self);
+            whisper_kv_cache_clear(state->kv_self);
 
-            whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0);
+            whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
-            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+            if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                 return -7;
             }
@@ -5439,18 +5319,10 @@ int whisper_full_with_state(
 
             whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
-            state->decoders[0].kv_self.n += prompt.size();
-
             for (int j = 1; j < n_decoders_cur; ++j) {
                 auto & decoder = state->decoders[j];
 
-                // TODO: fix CUDA
-                //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-                ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
-                ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
-                decoder.kv_self.n += prompt.size();
+                whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                 memcpy(decoder.probs.data(),  state->decoders[0].probs.data(),  decoder.probs.size()*sizeof(decoder.probs[0]));
                 memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -5492,7 +5364,7 @@ int whisper_full_with_state(
                         const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
 
                         for (const auto & token : tokens_new) {
-                            beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
+                            beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
                             beam_candidates.back().sequence.tokens.push_back(token);
                             beam_candidates.back().sequence.sum_logprobs_all += token.plog;
 
@@ -5531,17 +5403,30 @@ int whisper_full_with_state(
                         ++cur_c;
                     }
 
-                    decoder.sequence   = cur.sequence;
                     decoder.seek_delta = cur.seek_delta;
                     decoder.has_ts     = cur.has_ts;
+                    decoder.sequence   = cur.sequence;
+                    decoder.grammar    = cur.grammar;
 
                     decoder_idx[j] = cur.decoder_idx;
 
+                    whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
+
                     WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                             __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                 }
 
                 // update KV caches
-                whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+                for (int j = 0; j < n_decoders_cur; ++j) {
+                    auto & decoder = state->decoders[j];
+
+                    if (decoder.completed || decoder.failed) {
+                        continue;
+                    }
+
+                    whisper_kv_cache_seq_rm(state->kv_self, j,                        -1, -1);
+                    whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+                    whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
+                }
             }
 
             // update the decoder state
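
With the unified cache, the beam permutation runs entirely on sequence ids: each surviving beam first copies its source sequence into a scratch id (WHISPER_MAX_DECODERS + j, which cannot collide with a live decoder id), and once all sources are staged the scratch ids are renamed back. A sketch of that two-phase pattern using the toy helpers from the earlier notes, plus a toy seq_rm mirroring the whisper_kv_cache_seq_rm added above (all sketch_* names are hypothetical):

    // toy seq_rm: drop seq_id's claim on cells in [p0, p1); negative seq_id
    // means all sequences; a cell with no owners left becomes free (pos = -1)
    static void sketch_kv_cache_seq_rm(
            whisper_kv_cache_sketch & cache,
            whisper_seq_id seq_id,
            whisper_pos p0,
            whisper_pos p1) {
        if (p0 < 0) p0 = 0;
        if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();

        for (auto & cell : cache.cells) {
            if (cell.pos < p0 || cell.pos >= p1) continue;
            if (seq_id < 0) {
                cell.seq_id.clear();
            } else {
                cell.seq_id.erase(seq_id);
            }
            if (cell.seq_id.empty()) {
                cell.pos = -1;
            }
        }
    }

    // hypothetical driver: apply the beam permutation view[] via scratch ids
    static void sketch_apply_view(
            whisper_kv_cache_sketch & cache,
            const std::vector<int> & view,
            int n_decoders,
            int max_decoders /* scratch offset, e.g. WHISPER_MAX_DECODERS */) {
        // phase 1: tag every needed source with a scratch id (no data moves)
        for (int j = 0; j < n_decoders; ++j) {
            if (view[j] >= 0) {
                sketch_kv_cache_seq_cp(cache, view[j], max_decoders + j, -1, -1);
            }
        }
        // phase 2: clear each destination, then rename scratch -> destination
        for (int j = 0; j < n_decoders; ++j) {
            if (view[j] < 0) continue;
            sketch_kv_cache_seq_rm(cache, j, -1, -1);
            sketch_kv_cache_seq_cp(cache, max_decoders + j, j, -1, -1);
            sketch_kv_cache_seq_rm(cache, max_decoders + j, -1, -1);
        }
    }

This mirrors the hunk above: the seq_cp into WHISPER_MAX_DECODERS + j happens while candidates are being selected, and the rm/cp/rm loop commits the result for every decoder that is still running.
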
@@ -5657,14 +5542,14 @@ int whisper_full_with_state(
                     continue;
                 }
 
-                decoder.tokens_tmp.resize(1);
-                decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+                //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
 
-                //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+                // TODO: use batch
+                const int n_past = prompt.size() + i;
 
-                whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n);
+                whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
 
-                if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -8;
                 }
@@ -5674,8 +5559,6 @@ int whisper_full_with_state(
 
                 whisper_process_logits(*ctx, *state, params, decoder, t_cur);
 
-                ++decoder.kv_self.n;
-
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
             }
         }