whisper : add new-segment callback

Can be used to process new segments as they are being generated.
Sample usage in main, for printing the resulting segments during the
inference.
This commit is contained in:
Georgi Gerganov
2022-10-22 21:06:50 +03:00
parent 8f95c25aed
commit 7affd309d3
3 changed files with 81 additions and 39 deletions

View File

@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
}
@ -2549,6 +2555,9 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
@ -2576,6 +2585,9 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
}
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
}
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].id;
}
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].p;
}