mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
whisper : add whisper_batch
This commit is contained in:
parent
d4231649e6
commit
3cbaaed060
325
whisper.cpp
325
whisper.cpp
@ -406,6 +406,59 @@ struct whisper_segment {
|
||||
bool speaker_turn_next;
|
||||
};
|
||||
|
||||
struct whisper_batch {
|
||||
int32_t n_tokens;
|
||||
|
||||
whisper_token * token;
|
||||
whisper_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
whisper_seq_id ** seq_id;
|
||||
int8_t * logits;
|
||||
};
|
||||
|
||||
static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
|
||||
whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
|
||||
|
||||
batch.token = (whisper_token *) malloc(sizeof(whisper_token) * n_tokens);
|
||||
|
||||
batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * n_tokens);
|
||||
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
||||
batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * n_tokens);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
|
||||
}
|
||||
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||
|
||||
return batch;
|
||||
}
|
||||
|
||||
static void whisper_batch_free(struct whisper_batch batch) {
|
||||
if (batch.token) free(batch.token);
|
||||
if (batch.pos) free(batch.pos);
|
||||
if (batch.n_seq_id) free(batch.n_seq_id);
|
||||
if (batch.seq_id) {
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
free(batch.seq_id[i]);
|
||||
}
|
||||
free(batch.seq_id);
|
||||
}
|
||||
if (batch.logits) free(batch.logits);
|
||||
}
|
||||
|
||||
// Fill `batch` as a single-sequence (seq_id 0) prompt of n_tokens tokens at
// positions n_past..n_past+n_tokens-1 — the pre-batching decode path.
// `tokens` may be NULL when only positions/metadata are needed (e.g. when
// measuring the decode graph). Logits are requested for the last token only.
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) {
    batch.n_tokens = n_tokens;
    for (int i = 0; i < n_tokens; ++i) {
        if (tokens) {
            batch.token[i] = tokens[i];
        }
        batch.pos     [i]    = n_past + i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = 0;
    }
    // guard against an empty batch: unconditional write would touch logits[-1]
    if (n_tokens > 0) {
        batch.logits[n_tokens - 1] = 1;
    }
}
|
||||
|
||||
// medium
|
||||
// hparams: {
|
||||
// 'n_mels': 80,
|
||||
@ -523,15 +576,31 @@ struct whisper_layer_decoder {
|
||||
struct ggml_tensor * mlp_1_b;
|
||||
};
|
||||
|
||||
struct whisper_kv_cell {
|
||||
whisper_pos pos = -1;
|
||||
|
||||
std::set<whisper_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const whisper_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
};
|
||||
|
||||
struct whisper_kv_cache {
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<whisper_kv_cell> cells;
|
||||
|
||||
struct ggml_tensor * k;
|
||||
struct ggml_tensor * v;
|
||||
|
||||
struct ggml_context * ctx;
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
|
||||
int n; // number of tokens currently in the cache
|
||||
};
|
||||
|
||||
struct whisper_model {
|
||||
@ -723,6 +792,8 @@ struct whisper_state {
|
||||
whisper_kv_cache kv_cross;
|
||||
whisper_mel mel;
|
||||
|
||||
whisper_batch batch;
|
||||
|
||||
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
||||
|
||||
// buffer for swapping KV caches between decoders during beam-search
|
||||
@ -742,8 +813,9 @@ struct whisper_state {
|
||||
struct ggml_tensor * embd_conv = nullptr;
|
||||
struct ggml_tensor * embd_enc = nullptr;
|
||||
|
||||
// helper for GPU offloading
|
||||
// helpers for GPU offloading
|
||||
std::vector<float> inp_mel;
|
||||
std::vector<float> inp_mask;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
std::vector<float> logits;
|
||||
@ -831,6 +903,12 @@ static bool kv_cache_init(
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = n_ctx;
|
||||
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(n_ctx);
|
||||
|
||||
cache.ctx = ggml_init(params);
|
||||
|
||||
if (!cache.ctx) {
|
||||
@ -858,6 +936,14 @@ static bool kv_cache_init(
|
||||
return true;
|
||||
}
|
||||
|
||||
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
||||
if (cache.ctx) {
|
||||
ggml_free(cache.ctx);
|
||||
ggml_backend_buffer_free(cache.buffer);
|
||||
cache.ctx = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove after batched decoding
|
||||
static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
|
||||
WHISPER_ASSERT(cache.ctx);
|
||||
@ -901,11 +987,91 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t back
|
||||
return true;
|
||||
}
|
||||
|
||||
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
||||
if (cache.ctx) {
|
||||
ggml_free(cache.ctx);
|
||||
ggml_backend_buffer_free(cache.buffer);
|
||||
cache.ctx = nullptr;
|
||||
static bool whisper_kv_cache_find_slot(
|
||||
struct whisper_kv_cache & cache,
|
||||
const struct whisper_batch & batch) {
|
||||
const uint32_t n_ctx = cache.size;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
if (n_tokens > n_ctx) {
|
||||
WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t n_tested = 0;
|
||||
|
||||
while (true) {
|
||||
if (cache.head + n_tokens > n_ctx) {
|
||||
n_tested += n_ctx - cache.head;
|
||||
cache.head = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool found = true;
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (cache.cells[cache.head + i].pos >= 0) {
|
||||
found = false;
|
||||
cache.head += i + 1;
|
||||
n_tested += i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (found) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (n_tested >= n_ctx) {
|
||||
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
cache.cells[cache.head + i].pos = batch.pos[i];
|
||||
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
||||
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// find how many cells are currently in use
|
||||
static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
|
||||
for (uint32_t i = cache.size - 1; i > 0; --i) {
|
||||
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
||||
return i + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Mark every cell as unused and reset the search head to the start.
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
    const int32_t n = (int32_t) cache.size;

    for (int32_t i = 0; i < n; ++i) {
        auto & cell = cache.cells[i];
        cell.pos = -1;
        cell.seq_id.clear();
    }

    cache.head = 0;
}
|
||||
|
||||
static void whisper_kv_cache_seq_cp(
|
||||
struct whisper_kv_cache & cache,
|
||||
whisper_seq_id seq_id_src,
|
||||
whisper_seq_id seq_id_dst,
|
||||
whisper_pos p0,
|
||||
whisper_pos p1) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
|
||||
|
||||
cache.head = 0;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].seq_id.insert(seq_id_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2032,25 +2198,29 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
whisper_decoder & decoder,
|
||||
const whisper_token * tokens,
|
||||
int n_tokens,
|
||||
int n_past) {
|
||||
const whisper_batch & batch) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
// TODO: move to wstate
|
||||
auto & kv_self = decoder.kv_self;
|
||||
|
||||
WHISPER_ASSERT(!!kv_self.ctx);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
||||
|
||||
const int n_ctx = hparams.n_text_ctx;
|
||||
const int n_state = hparams.n_text_state;
|
||||
const int n_head = hparams.n_text_head;
|
||||
const int n_layer = hparams.n_text_layer;
|
||||
|
||||
const int N = n_tokens;
|
||||
const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
||||
const int n_tokens = batch.n_tokens;
|
||||
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
||||
|
||||
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
||||
const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
//WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
||||
@ -2062,21 +2232,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_allocr_alloc(alloc, embd);
|
||||
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
|
||||
ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
|
||||
}
|
||||
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_allocr_alloc(alloc, position);
|
||||
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
const int32_t val = n_past + i;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const int32_t val = batch.pos[i];
|
||||
ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
@ -2089,6 +2257,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
||||
}
|
||||
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||
ggml_allocr_alloc(alloc, KQ_mask);
|
||||
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
wstate.inp_mask.resize(n_kv*n_tokens);
|
||||
|
||||
float * data = wstate.inp_mask.data();
|
||||
memset(data, 0, ggml_nbytes(KQ_mask));
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const whisper_pos pos = batch.pos[j];
|
||||
const whisper_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
|
||||
}
|
||||
|
||||
// token encoding + position encoding
|
||||
struct ggml_tensor * cur =
|
||||
ggml_add(ctx0,
|
||||
@ -2141,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
Vcur,
|
||||
layer.attn_v_b);
|
||||
|
||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
|
||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
@ -2156,12 +2349,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_state/n_head, n_past + N, n_head,
|
||||
n_state/n_head, n_kv, n_head,
|
||||
ggml_element_size(kv_self.k)*n_state,
|
||||
ggml_element_size(kv_self.k)*n_state/n_head,
|
||||
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
||||
@ -2171,13 +2364,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
||||
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
||||
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_past + N, n_state/n_head, n_head,
|
||||
n_kv, n_state/n_head, n_head,
|
||||
n_ctx*ggml_element_size(kv_self.v),
|
||||
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
||||
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
|
||||
@ -2188,7 +2382,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
||||
}
|
||||
|
||||
// projection
|
||||
@ -2232,33 +2426,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
// Kcross is already scaled
|
||||
struct ggml_tensor * Kcross =
|
||||
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
||||
n_state/n_head, M, n_head,
|
||||
n_state/n_head, n_audio_ctx, n_head,
|
||||
ggml_element_size(wstate.kv_cross.k)*n_state,
|
||||
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
||||
ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
|
||||
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
||||
|
||||
//struct ggml_tensor * Vcross =
|
||||
// ggml_reshape_3d(ctx0,
|
||||
// ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
|
||||
// n_state/n_head, n_head, M);
|
||||
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
||||
// n_state/n_head, n_head, n_audio_ctx);
|
||||
|
||||
//struct ggml_tensor * V_trans =
|
||||
// ggml_cpy(ctx0,
|
||||
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
||||
// ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
|
||||
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
||||
M, n_state/n_head, n_head,
|
||||
M*ggml_element_size(wstate.kv_cross.v),
|
||||
M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
||||
il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
|
||||
n_audio_ctx, n_state/n_head, n_head,
|
||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
||||
|
||||
// ------
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
|
||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
@ -2279,10 +2473,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_state, N)
|
||||
// cur = KQV_merged.contiguous().view(n_state, n_tokens)
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
||||
}
|
||||
|
||||
// projection
|
||||
@ -2354,7 +2548,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
}
|
||||
|
||||
// compute logits only for the last token
|
||||
// comment this line to compute logits for all N tokens
|
||||
// comment this line to compute logits for all n_tokens
|
||||
// might be useful in the future
|
||||
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
|
||||
|
||||
@ -2381,9 +2575,7 @@ static bool whisper_decode_internal(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
whisper_decoder & decoder,
|
||||
const whisper_token * tokens,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const whisper_batch & batch,
|
||||
const int n_threads,
|
||||
whisper_abort_callback abort_callback,
|
||||
void * abort_callback_data) {
|
||||
@ -2398,13 +2590,21 @@ static bool whisper_decode_internal(
|
||||
|
||||
struct ggml_tensor * logits;
|
||||
|
||||
auto & kv_self = decoder.kv_self;
|
||||
|
||||
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
||||
|
||||
// decoder
|
||||
{
|
||||
auto & alloc = wstate.alloc_decode.alloc;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, batch);
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, gf);
|
||||
|
||||
@ -2423,7 +2623,7 @@ static bool whisper_decode_internal(
|
||||
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
||||
ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
|
||||
|
||||
if (n_tokens > 1) {
|
||||
if (batch.n_tokens > 1) {
|
||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
||||
@ -2432,7 +2632,7 @@ static bool whisper_decode_internal(
|
||||
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
||||
}
|
||||
|
||||
if (n_tokens == 1) {
|
||||
if (batch.n_tokens == 1) {
|
||||
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_decode++;
|
||||
} else {
|
||||
@ -2443,7 +2643,6 @@ static bool whisper_decode_internal(
|
||||
return !(abort_callback && abort_callback(abort_callback_data));
|
||||
}
|
||||
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
static std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
@ -2899,6 +3098,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
state->logits_id.reserve(ctx->model.hparams.n_vocab);
|
||||
|
||||
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
|
||||
|
||||
// TAGS: WHISPER_DECODER_INIT
|
||||
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
||||
|
||||
@ -2946,7 +3147,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
const int n_tokens = hparams.n_text_ctx;
|
||||
const int n_past = 0;
|
||||
|
||||
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
|
||||
whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past);
|
||||
|
||||
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch);
|
||||
});
|
||||
|
||||
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
||||
@ -3203,6 +3406,8 @@ void whisper_free_state(struct whisper_state * state)
|
||||
}
|
||||
#endif
|
||||
|
||||
whisper_batch_free(state->batch);
|
||||
|
||||
whisper_allocr_free(state->alloc_conv);
|
||||
whisper_allocr_free(state->alloc_encode);
|
||||
whisper_allocr_free(state->alloc_cross);
|
||||
@ -3331,7 +3536,9 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
||||
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
||||
const int selected_decoder_id = 0;
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
||||
whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -3348,7 +3555,9 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
||||
whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -5216,7 +5425,11 @@ int whisper_full_with_state(
|
||||
}
|
||||
WHISPER_PRINT_DEBUG("\n\n");
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
whisper_kv_cache_clear(state->decoders[0].kv_self);
|
||||
|
||||
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
||||
return -7;
|
||||
}
|
||||
@ -5449,7 +5662,9 @@ int whisper_full_with_state(
|
||||
|
||||
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
||||
return -8;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user