wip : experimental color coding of tokens based on probabilities

commit 31ff0c6a1f
parent f4aa01c2f8
Author: Georgi Gerganov
Date:   2022-10-21 17:33:59 +03:00

3 changed files with 134 additions and 57 deletions

whisper.cpp

@@ -210,9 +210,12 @@ struct whisper_vocab {
     }
 };
 
-struct whisper_result {
-    int64_t t;
-    whisper_token id;
+struct whisper_token_data {
+    whisper_token id;  // token id
+    whisper_token tid; // forced timestamp token id
+
+    float p;  // probability of the token
+    float pt; // probability of the timestamp token
 };
 
 struct whisper_segment {
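The new whisper_token_data struct is what enables the color coding from the commit title: every sampled token now carries its own probability p, plus the id and probability of the most likely timestamp token. Below is a minimal sketch (not part of this commit) of how a caller might map p to a terminal color; it assumes an ANSI-capable terminal, and the thresholds and sample tokens are made-up illustration values.

    #include <cstdio>

    // map a token probability to an ANSI terminal color:
    // red = low confidence, yellow = medium, green = high
    // (bucket boundaries are arbitrary illustration values)
    static const char * prob_color(float p) {
        if (p < 0.33f) return "\033[31m"; // red
        if (p < 0.66f) return "\033[33m"; // yellow
        return "\033[32m";                // green
    }

    int main() {
        // hypothetical tokens with the probabilities a whisper_token_data
        // would carry in its `p` field
        const struct { const char * text; float p; } tokens[] = {
            { " Hello", 0.92f }, { ",", 0.41f }, { " world", 0.17f },
        };

        for (const auto & t : tokens) {
            printf("%s%s\033[0m", prob_color(t.p), t.text);
        }
        printf("\n");
        return 0;
    }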
@@ -220,6 +223,8 @@ struct whisper_segment {
     int64_t t1;
 
     std::string text;
+
+    std::vector<whisper_token_data> tokens;
 };
 
 // medium
@@ -407,7 +412,7 @@ struct whisper_context {
     std::vector<float> probs;
     std::vector<float> logits;
 
-    std::vector<whisper_result> result_cur;
+    std::vector<whisper_token_data> tokens_cur;
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
@@ -1786,9 +1791,11 @@ bool whisper_decode(
 }
 
 // the most basic sampling scheme - select the top token
-whisper_vocab::id whisper_sample_best(
+whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs) {
+    whisper_token_data result;
+
     int n_logits = vocab.id_to_token.size();
 
     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1798,24 +1805,33 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }
 
-    double sum_ts = 0.0;
-    double max_tx = 0.0;
-
     {
+        double sum_ts = 0.0;
+        double max_ts = -1.0;
+        double max_tx = -1.0;
+
         for (int i = 0; i < vocab.token_beg; i++) {
             max_tx = std::max(max_tx, probs_id[i].first);
         }
 
         for (int i = vocab.token_beg; i < n_logits; i++) {
             sum_ts += probs_id[i].first;
+            if (probs_id[i].first > max_ts) {
+                max_ts = probs_id[i].first;
+                result.tid = probs_id[i].second;
+            }
         }
 
-        // if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
+        // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
         // timestamp token
         if (sum_ts > max_tx) {
             // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
             for (int i = 0; i < vocab.token_beg; i++) {
                 probs_id[i].first = -INFINITY;
             }
         }
+
+        result.pt = max_ts/(sum_ts + 1e-6);
     }
 
     // find the top K tokens
@@ -1843,7 +1859,10 @@ whisper_vocab::id whisper_sample_best(
         res++;
     }
 
-    return probs_id[res].second;
+    result.id = probs_id[res].second;
+    result.p  = probs_id[res].first;
+
+    return result;
 }
 
 // samples only from the timestamps tokens
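A worked example of the new pt field: sum_ts accumulates the probability mass of all timestamp tokens, max_ts tracks the single most likely one, so pt = max_ts/(sum_ts + 1e-6) is the top timestamp's share of that mass, with the 1e-6 guarding against division by zero. A small standalone check of the arithmetic (the three probability values are made up):

    #include <cstdio>

    int main() {
        // suppose three timestamp tokens receive probability mass 0.05, 0.30, 0.15
        const double probs_ts[] = { 0.05, 0.30, 0.15 };

        double sum_ts = 0.0;
        double max_ts = -1.0;
        for (const double p : probs_ts) {
            sum_ts += p;
            if (p > max_ts) max_ts = p;
        }

        // same formula as in whisper_sample_best above
        printf("pt = %.3f\n", max_ts/(sum_ts + 1e-6)); // prints pt = 0.600
        return 0;
    }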
@@ -2178,7 +2197,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
-    return res;
+    return res.id;
 }
 
 whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
@@ -2343,7 +2362,7 @@ int whisper_full(
         int n_samples) {
     // clear old results
     auto & result_all = ctx->result_all;
-    auto & result_cur = ctx->result_cur;
+    auto & tokens_cur = ctx->tokens_cur;
 
     result_all.clear();
@@ -2430,7 +2449,7 @@ int whisper_full(
         // the accumulated transcription in the current interation
         int result_len = 0;
 
-        result_cur.clear();
+        tokens_cur.clear();
 
         for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
@@ -2449,28 +2468,26 @@ int whisper_full(
             //   feel free to experiment!
             //
             {
-                whisper_token id  = 0;
-                whisper_token tid = whisper_token_beg(ctx);
+                auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
 
-                id = whisper_sample_best(ctx);
-                if (i > 0) {
-                    tid = whisper_sample_timestamp(ctx);
+                if (i == 0) {
+                    token.tid = whisper_token_beg(ctx);
                 }
 
-                // update sliding window
-                if (id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(id - whisper_token_beg(ctx));
+                // timestamp token - update sliding window
+                if (token.id > whisper_token_beg(ctx)) {
+                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
                     result_len = i + 1;
                 }
 
                 // add it to the context
-                prompt.push_back(id);
-                result_cur.push_back({ seek + 2*(tid - whisper_token_beg(ctx)), id });
+                prompt.push_back(token.id);
+                tokens_cur.push_back(token);
 
                 //printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
 
                 // end of text token
-                if (id == whisper_token_eot(ctx)) {
+                if (token.id == whisper_token_eot(ctx)) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
                             result_len = i + 1;
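The 2*(token.id - whisper_token_beg(ctx)) conversion above relies on Whisper's timestamp tokens being spaced 0.02 s apart while whisper.cpp keeps segment times in 10 ms units; the factor of 2 bridges the two. A hedged sketch of that conversion (the concrete token id is made up for illustration):

    #include <cstdio>

    int main() {
        // hypothetical ids: token_beg stands for the <|0.00|> timestamp token
        const int token_beg = 50363;
        const int tid       = token_beg + 150; // i.e. the <|3.00|> token

        // timestamp tokens step by 0.02 s; whisper.cpp stores times in 10 ms units
        const int t = 2*(tid - token_beg);
        printf("t = %d (%.2f s)\n", t, 0.01*t); // t = 300 (3.00 s)
        return 0;
    }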
@@ -2494,25 +2511,30 @@ int whisper_full(
             }
         }
 
-        result_cur.resize(result_len);
+        tokens_cur.resize(result_len);
 
-        for (const auto & r : result_cur) {
+        for (const auto & r : tokens_cur) {
             prompt_past.push_back(r.id);
         }
 
         // store the text from this iteration
-        if (result_cur.size() > 0) {
-            auto t0 = result_cur.front().t;
+        if (tokens_cur.size() > 0) {
+            int i0 = 0;
+            auto t0 = 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
             std::string text = "";
 
-            for (int i = 0; i < (int) result_cur.size(); i++) {
-                if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
+            for (int i = 0; i < (int) tokens_cur.size(); i++) {
+                //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+                //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+                //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+                if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
                 } else {
-                    text += whisper_token_to_str(ctx, result_cur[i].id);
+                    text += whisper_token_to_str(ctx, tokens_cur[i].id);
                 }
-                if (result_cur[i].id > whisper_token_beg(ctx)) {
-                    const auto t1 = result_cur[i].t;
+                if (tokens_cur[i].id > whisper_token_beg(ctx)) {
+                    const auto t1 = 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
                     if (!text.empty()) {
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -2523,14 +2545,18 @@ int whisper_full(
                             }
                         }
 
-                        result_all.push_back({ t0, t1, text });
+                        result_all.push_back({ t0, t1, text, {} });
+                        for (int j = i0; j <= i; j++) {
+                            result_all.back().tokens.push_back(tokens_cur[j]);
+                        }
                     }
                     text = "";
-                    while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) {
+                    while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
                         i++;
                     }
                     i--;
-                    t0 = result_cur[i].t;
+                    t0 = t1;
+                    i0 = i + 1;
                 }
             }
@@ -2546,7 +2572,10 @@ int whisper_full(
                     }
                 }
 
-                result_all.push_back({ t0, t1, text });
+                result_all.push_back({ t0, t1, text, {} });
+                for (int j = i0; j < (int) tokens_cur.size(); j++) {
+                    result_all.back().tokens.push_back(tokens_cur[j]);
+                }
             }
         }
@@ -2571,3 +2600,15 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
     return ctx->result_all[i_segment].text.c_str();
 }
+
+int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].p;
+}