whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058)

* add HuggingFace mirror to download ggml model

* support tdrz via simple hack overriding solm tokens

* fix incorrect translate/transcribe token_ids that are not static const

* add apollo 13 sample for tdrz demo

* render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token

* extend whisper_segment with speaker_turn_next field and save in json output

* fix failing go build

* slipped in some python syntax whoops

* whisper : finalize tinydiarize support (add flag + fixes)

* whisper : tdrz support for word-level timestamps (respect max_len)

* java : try to fix tests after adding tdrz_enable flag

* main : remove TODO leftover

* java : fix params order list after adding "tdrz_enable"

* whisper : fix solm and add nosp token

* main : print tinydiarize help

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Akash Mahajan
2023-07-03 23:45:00 -07:00
committed by GitHub
parent fdf58a6668
commit c8d0f5fe98
8 changed files with 215 additions and 130 deletions

View File

@ -380,16 +380,18 @@ struct whisper_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
id token_eot = 50256;
id token_sot = 50257;
id token_prev = 50360;
id token_solm = 50361; // ??
id token_not = 50362; // no timestamps
id token_beg = 50363;
// available tasks
static const id token_translate = 50358;
static const id token_transcribe = 50359;
// reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
id token_eot = 50256;
id token_sot = 50257;
// task tokens (used only for multilingual models)
id token_translate = 50357;
id token_transcribe = 50358;
// other special tokens
id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
id token_prev = 50360;
id token_nosp = 50361;
id token_not = 50362; // no timestamps
id token_beg = 50363; // begin timestamps
bool is_multilingual() const {
return n_vocab == 51865;
@ -403,6 +405,8 @@ struct whisper_segment {
std::string text;
std::vector<whisper_token_data> tokens;
bool speaker_turn_next;
};
// medium
@ -966,8 +970,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
if (vocab.is_multilingual()) {
vocab.token_eot++;
vocab.token_sot++;
vocab.token_prev++;
vocab.token_translate++;
vocab.token_transcribe++;
vocab.token_solm++;
vocab.token_prev++;
vocab.token_nosp++;
vocab.token_not++;
vocab.token_beg++;
}
@ -981,8 +988,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "[_EOT_]";
} else if (i == vocab.token_sot) {
word = "[_SOT_]";
} else if (i == vocab.token_solm) {
word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
word = "[_PREV_]";
} else if (i == vocab.token_nosp) {
word = "[_NOSP_]";
} else if (i == vocab.token_not) {
word = "[_NOT_]";
} else if (i == vocab.token_beg) {
@ -3208,12 +3219,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) {
return ctx->vocab.token_sot;
}
// Returns the id stored in the context vocabulary's `token_solm` field.
// Per the vocabulary definition, tinydiarize ([TDRZ]) models use this token
// to indicate a speaker turn.
whisper_token whisper_token_solm(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_solm;
}
// Returns the id stored in the context vocabulary's `token_prev` field
// (the special token rendered as "[_PREV_]" when the vocabulary is loaded).
whisper_token whisper_token_prev(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_prev;
}
whisper_token whisper_token_solm(struct whisper_context * ctx) {
return ctx->vocab.token_solm;
// Returns the id stored in the context vocabulary's `token_nosp` field
// (the special token rendered as "[_NOSP_]" when the vocabulary is loaded).
whisper_token whisper_token_nosp(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_nosp;
}
whisper_token whisper_token_not(struct whisper_context * ctx) {
@ -3228,12 +3243,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
return whisper_token_sot(ctx) + 1 + lang_id;
}
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
// Returns the id of the "translate" task token for the given context.
// The value is read from the context's vocabulary because model loading
// shifts the task-token ids when the vocabulary is multilingual.
whisper_token whisper_token_translate(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_translate;
}
whisper_token whisper_token_transcribe(void) {
return whisper_vocab::token_transcribe;
// Returns the id of the "transcribe" task token for the given context.
// The value is read from the context's vocabulary because model loading
// shifts the task-token ids when the vocabulary is multilingual.
whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_transcribe;
}
void whisper_print_timings(struct whisper_context * ctx) {
@ -3305,51 +3320,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
/*.strategy =*/ strategy,
/*.strategy =*/ strategy,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ true,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.translate =*/ false,
/*.no_context =*/ true,
/*.single_segment =*/ false,
/*.print_special =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.initial_prompt =*/ nullptr,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.tdrz_enable =*/ false,
/*.language =*/ "en",
/*.detect_language =*/ false,
/*.initial_prompt =*/ nullptr,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.suppress_blank =*/ true,
/*.language =*/ "en",
/*.detect_language =*/ false,
/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_inc =*/ 0.4f,
/*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f,
/*.temperature_inc =*/ 0.4f,
/*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f,
/*.greedy =*/ {
/*.greedy =*/ {
/*.best_of =*/ -1,
},
@ -3430,6 +3447,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
state.result_all.back().text = std::move(text);
state.result_all.back().t1 = token.t0;
state.result_all.back().tokens.resize(i);
state.result_all.back().speaker_turn_next = false;
state.result_all.push_back({});
state.result_all.back().t0 = token.t0;
@ -3441,6 +3459,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
segment.tokens.begin() + i,
segment.tokens.end());
state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
acc = 0;
text = "";
@ -3519,9 +3539,14 @@ static void whisper_process_logits(
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;
// suppress sot and solm tokens
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_solm] = -INFINITY;
logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
// [TDRZ] when tinydiarize is disabled, suppress solm token
if (params.tdrz_enable == false) {
logits[vocab.token_solm] = -INFINITY;
}
// suppress task tokens
logits[vocab.token_translate] = -INFINITY;
@ -4018,9 +4043,9 @@ int whisper_full_with_state(
state->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
prompt_init.push_back(whisper_token_translate(ctx));
} else {
prompt_init.push_back(whisper_token_transcribe());
prompt_init.push_back(whisper_token_transcribe(ctx));
}
}
@ -4500,23 +4525,27 @@ int whisper_full_with_state(
prompt_past.push_back(tokens_cur[i].id);
}
// store the text from this iteration
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
std::string text;
bool speaker_turn_next = false;
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
// [TDRZ] record if speaker turn was predicted after current segment
if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
speaker_turn_next = true;
}
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
@ -4535,7 +4564,7 @@ int whisper_full_with_state(
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
result_all.push_back({ tt0, tt1, text, {} });
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -4561,6 +4590,7 @@ int whisper_full_with_state(
i--;
t0 = t1;
i0 = i + 1;
speaker_turn_next = false;
}
}
@ -4579,7 +4609,7 @@ int whisper_full_with_state(
}
}
result_all.push_back({ tt0, tt1, text, {} });
result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -4759,6 +4789,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
return ctx->state->result_all[i_segment].t1;
}
// Returns the `speaker_turn_next` flag of segment `i_segment` in the results
// held by the context's state. The flag is set during decoding when
// tinydiarize is enabled and a speaker-turn token was seen for the segment.
// `i_segment` is used as-is to index the result list; no range check is done.
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
    const auto & segment = ctx->state->result_all[i_segment];
    return segment.speaker_turn_next;
}
// Returns the text of segment `i_segment` stored in `state` as a C string.
// The pointer refers to the buffer of the segment's own string, so it is not
// a copy. `i_segment` is used as-is to index the result list; no range check
// is done.
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];
    return segment.text.c_str();
}