whisper : add whisper_batch

Georgi Gerganov 2023-11-13 22:39:43 +02:00
parent d4231649e6
commit 3cbaaed060
2 changed files with 273 additions and 56 deletions

whisper.cpp

@@ -406,6 +406,59 @@ struct whisper_segment {
bool speaker_turn_next;
};
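// a batch of tokens to submit to the decoder in a single call
// modeled after llama_batch from llama.cpp: each token carries its position,
// the set of sequences it belongs to, and a flag that selects whether
// logits are computed for it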
struct whisper_batch {
int32_t n_tokens;
whisper_token * token;
whisper_pos * pos;
int32_t * n_seq_id;
whisper_seq_id ** seq_id;
int8_t * logits;
};
static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
batch.token = (whisper_token *) malloc(sizeof(whisper_token) * n_tokens);
batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * n_tokens);
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
}
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
static void whisper_batch_free(struct whisper_batch batch) {
if (batch.token) free(batch.token);
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
for (int i = 0; i < batch.n_tokens; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
}
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) {
batch.n_tokens = n_tokens;
for (int i = 0; i < n_tokens; ++i) {
if (tokens) {
batch.token[i] = tokens[i];
}
batch.pos [i] = n_past + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i][0] = 0;
batch.logits [i] = 0;
}
batch.logits[n_tokens - 1] = 1;
}
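
For orientation, a minimal sketch of how these three helpers fit together; the sizes and token ids are made up for illustration:

// illustrative only: room for up to 8 tokens, each in at most 1 sequence
whisper_batch batch = whisper_batch_init(8, 1);

const whisper_token prompt[3] = { 1, 2, 3 }; // hypothetical token ids
whisper_batch_prep_legacy(batch, prompt, 3, /*n_past =*/ 0);
// -> pos = {0, 1, 2}, every token in sequence 0, logits only for the last token

whisper_batch_free(batch);
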
// medium
// hparams: {
// 'n_mels': 80,
@@ -523,15 +576,31 @@ struct whisper_layer_decoder {
struct ggml_tensor * mlp_1_b;
};
struct whisper_kv_cell {
whisper_pos pos = -1;
std::set<whisper_seq_id> seq_id;
bool has_seq_id(const whisper_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
};
struct whisper_kv_cache {
uint32_t head = 0;
uint32_t size = 0;
// computed before each graph build
uint32_t n = 0;
std::vector<whisper_kv_cell> cells;
struct ggml_tensor * k;
struct ggml_tensor * v;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer;
int n; // number of tokens currently in the cache
};
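// head is where the next slot search starts, size is the total number of
// cells (n_ctx), and n is an upper bound on the cells currently in use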
struct whisper_model {
@@ -723,6 +792,8 @@ struct whisper_state {
whisper_kv_cache kv_cross;
whisper_mel mel;
whisper_batch batch;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
// buffer for swapping KV caches between decoders during beam-search
@@ -742,8 +813,9 @@ struct whisper_state {
struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr;
// helper for GPU offloading
// helpers for GPU offloading
std::vector<float> inp_mel;
std::vector<float> inp_mask;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@@ -831,6 +903,12 @@ static bool kv_cache_init(
/*.no_alloc =*/ true,
};
cache.head = 0;
cache.size = n_ctx;
cache.cells.clear();
cache.cells.resize(n_ctx);
cache.ctx = ggml_init(params);
if (!cache.ctx) {
@@ -858,6 +936,14 @@ static bool kv_cache_init(
return true;
}
static void kv_cache_free(struct whisper_kv_cache & cache) {
if (cache.ctx) {
ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr;
}
}
// TODO: remove after batched decoding
static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
WHISPER_ASSERT(cache.ctx);
@@ -901,11 +987,91 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t back
return true;
}
static void kv_cache_free(struct whisper_kv_cache & cache) {
if (cache.ctx) {
ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr;
static bool whisper_kv_cache_find_slot(
struct whisper_kv_cache & cache,
const struct whisper_batch & batch) {
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens > n_ctx) {
WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false;
}
uint32_t n_tested = 0;
while (true) {
if (cache.head + n_tokens > n_ctx) {
n_tested += n_ctx - cache.head;
cache.head = 0;
continue;
}
bool found = true;
for (uint32_t i = 0; i < n_tokens; i++) {
if (cache.cells[cache.head + i].pos >= 0) {
found = false;
cache.head += i + 1;
n_tested += i + 1;
break;
}
}
if (found) {
break;
}
if (n_tested >= n_ctx) {
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false;
}
}
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
}
}
return true;
}
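
The search treats the cell array as a ring: starting at head it looks for n_tokens consecutive free cells (pos < 0), wraps to the start whenever the tail would run past n_ctx, and fails once all n_ctx cells have been tested. A hypothetical trace with size = 8, head = 1, n_tokens = 3:

// cells: [used used free free free used free free]
// head = 1: cell 1 is occupied -> head = 2, n_tested = 1
// head = 2: cells 2..4 are free -> slot found; their pos and seq_id
// are filled from the batch
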
// find how many cells are currently in use
static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
for (uint32_t i = cache.size - 1; i > 0; --i) {
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
return i + 1;
}
}
return 0;
}
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}
cache.head = 0;
}
static void whisper_kv_cache_seq_cp(
struct whisper_kv_cache & cache,
whisper_seq_id seq_id_src,
whisper_seq_id seq_id_dst,
whisper_pos p0,
whisper_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.insert(seq_id_dst);
}
}
}
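
Copying a sequence id only tags the cells; no K/V data moves, which is what makes forking decoders cheap. A hypothetical beam-search style use, assuming the prompt was already decoded into sequence 0 of a cache named kv_self:

// let beams 1..4 reuse the prompt cells instead of re-decoding them
for (whisper_seq_id s = 1; s < 5; ++s) {
whisper_kv_cache_seq_cp(kv_self, 0, s, -1, -1); // -1/-1 = whole position range
}
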
@@ -2032,25 +2198,29 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_token * tokens,
int n_tokens,
int n_past) {
const whisper_batch & batch) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
// TODO: move to wstate
auto & kv_self = decoder.kv_self;
WHISPER_ASSERT(!!kv_self.ctx);
ggml_allocr * alloc = wstate.alloc_decode.alloc;
const int n_ctx = hparams.n_text_ctx;
const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer;
const int N = n_tokens;
const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_tokens = batch.n_tokens;
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
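// during the measure pass the allocator must see the worst case, so a full
// text context is assumed; at run time n_kv and kv_head come from the cache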
//WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
@@ -2062,21 +2232,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
ggml_allocr * alloc = wstate.alloc_decode.alloc;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(alloc, embd);
if (!ggml_allocr_is_measure(alloc)) {
ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
}
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(alloc, position);
if (!ggml_allocr_is_measure(alloc)) {
for (int i = 0; i < N; ++i) {
const int32_t val = n_past + i;
for (int i = 0; i < n_tokens; ++i) {
const int32_t val = batch.pos[i];
ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
}
}
@@ -2089,6 +2257,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
}
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_allocr_alloc(alloc, KQ_mask);
if (!ggml_allocr_is_measure(alloc)) {
wstate.inp_mask.resize(n_kv*n_tokens);
float * data = wstate.inp_mask.data();
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const whisper_pos pos = batch.pos[j];
const whisper_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
}
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
}
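// the mask row for token j hides every cell that belongs to a different
// sequence or holds a position after batch.pos[j], i.e. the usual causal
// mask, generalized to a cache shared by several sequences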
// token encoding + position encoding
struct ggml_tensor * cur =
ggml_add(ctx0,
@@ -2141,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
Vcur,
layer.attn_v_b);
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2156,12 +2349,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_state/n_head, n_past + N, n_head,
n_state/n_head, n_kv, n_head,
ggml_element_size(kv_self.k)*n_state,
ggml_element_size(kv_self.k)*n_state/n_head,
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,13 +2364,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_state/n_head, n_head,
n_kv, n_state/n_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
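// K and V now span the first n_kv cells of this layer's cache; cells that
// must not be attended to are suppressed by KQ_mask instead of the view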
@@ -2188,7 +2382,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
}
// projection
@@ -2232,33 +2426,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// Kcross is already scaled
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state/n_head, M, n_head,
n_state/n_head, n_audio_ctx, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
//struct ggml_tensor * Vcross =
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
// n_state/n_head, n_head, M);
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
// n_state/n_head, n_head, n_audio_ctx);
//struct ggml_tensor * V_trans =
// ggml_cpy(ctx0,
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
struct ggml_tensor * V =
ggml_view_3d(ctx0, wstate.kv_cross.v,
M, n_state/n_head, n_head,
M*ggml_element_size(wstate.kv_cross.v),
M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
n_audio_ctx, n_state/n_head, n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
// ------
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
0, 2, 1, 3);
// K * Q
@@ -2279,10 +2473,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_state, N)
// cur = KQV_merged.contiguous().view(n_state, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
}
// projection
@@ -2354,7 +2548,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
}
// compute logits only for the last token
// comment this line to compute logits for all N tokens
// comment this line to compute logits for all n_tokens
// might be useful in the future
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
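// cur is [n_state, n_tokens] here; the 1-row view at byte offset
// (n_tokens - 1)*nb[1] selects the last token only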
@@ -2381,9 +2575,7 @@ static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_token * tokens,
const int n_tokens,
const int n_past,
const whisper_batch & batch,
const int n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
@@ -2398,13 +2590,21 @@ static bool whisper_decode_internal(
struct ggml_tensor * logits;
auto & kv_self = decoder.kv_self;
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return false;
}
kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
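// the graph sees at least 32 cells and at most n_text_ctx, based on the
// highest cell currently in use (presumably a floor that keeps the K*Q
// shapes from becoming degenerate for single-token decodes)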
// decoder
{
auto & alloc = wstate.alloc_decode.alloc;
ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, batch);
ggml_allocr_alloc_graph(alloc, gf);
@@ -2423,7 +2623,7 @@ static bool whisper_decode_internal(
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
if (n_tokens > 1) {
if (batch.n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1024.0/1024.0,
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2432,7 +2632,7 @@ static bool whisper_decode_internal(
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
}
if (n_tokens == 1) {
if (batch.n_tokens == 1) {
wstate.t_decode_us += ggml_time_us() - t_start_us;
wstate.n_decode++;
} else {
@@ -2443,7 +2643,6 @@ static bool whisper_decode_internal(
return !(abort_callback && abort_callback(abort_callback_data));
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2899,6 +3098,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->logits_id.reserve(ctx->model.hparams.n_vocab);
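// one batch shared by all decoders: capacity for a full text context, with
// room for up to WHISPER_MAX_DECODERS sequence ids per token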
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
// TAGS: WHISPER_DECODER_INIT
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
@@ -2946,7 +3147,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past);
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch);
});
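// the measure graph is built from a full-context legacy batch (n_text_ctx
// tokens at n_past = 0), sizing the decode allocator for the worst case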
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -3203,6 +3406,8 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
whisper_batch_free(state->batch);
whisper_allocr_free(state->alloc_conv);
whisper_allocr_free(state->alloc_encode);
whisper_allocr_free(state->alloc_cross);
@@ -3331,7 +3536,9 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int selected_decoder_id = 0;
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past);
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
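
The public entry points keep their old signatures; the batch plumbing stays internal. A hypothetical caller (token ids made up) is unchanged:

std::vector<whisper_token> toks = { 0, 1, 2 }; // hypothetical token ids
if (whisper_decode_with_state(ctx, state, toks.data(), (int) toks.size(), /*n_past =*/ 0, /*n_threads =*/ 4) != 0) {
// handle the error
}
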
@@ -3348,7 +3555,9 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return false;
}
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past);
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return 1;
}
@@ -5216,7 +5425,11 @@ int whisper_full_with_state(
}
WHISPER_PRINT_DEBUG("\n\n");
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
whisper_kv_cache_clear(state->decoders[0].kv_self);
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0);
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -7;
}
@@ -5449,7 +5662,9 @@ int whisper_full_with_state(
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n);
if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8;
}

whisper.h

@@ -78,7 +78,9 @@ extern "C" {
struct whisper_state;
struct whisper_full_params;
typedef int whisper_token;
typedef int32_t whisper_pos;
typedef int32_t whisper_token;
typedef int32_t whisper_seq_id;
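// whisper_token is now fixed-width; whisper_pos and whisper_seq_id mirror
// the llama_pos / llama_seq_id types in llama.cpp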
struct whisper_context_params {
bool use_gpu;