server : add no-speech threshold parameter and functionality (#2654)

This commit is contained in:
Sacha Arbonel 2024-12-21 16:00:08 +01:00 committed by GitHub
parent f4668169a0
commit 4183517076
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 3 deletions

View File

@ -61,6 +61,7 @@ struct whisper_params {
float logprob_thold = -1.00f; float logprob_thold = -1.00f;
float temperature = 0.00f; float temperature = 0.00f;
float temperature_inc = 0.20f; float temperature_inc = 0.20f;
float no_speech_thold = 0.6f;
bool debug_mode = false; bool debug_mode = false;
bool translate = false; bool translate = false;
@ -137,6 +138,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -182,6 +184,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
// server params // server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; } else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@ -790,6 +794,7 @@ int main(int argc, char ** argv) {
wparams.beam_search.beam_size = params.beam_size; wparams.beam_search.beam_size = params.beam_size;
wparams.temperature = params.temperature; wparams.temperature = params.temperature;
wparams.no_speech_thold = params.no_speech_thold;
wparams.temperature_inc = params.temperature_inc; wparams.temperature_inc = params.temperature_inc;
wparams.entropy_thold = params.entropy_thold; wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold; wparams.logprob_thold = params.logprob_thold;
@ -942,7 +947,7 @@ int main(int argc, char ** argv) {
// TODO compression_ratio and no_speech_prob are not implemented yet // TODO compression_ratio and no_speech_prob are not implemented yet
// segment["compression_ratio"] = 0; // segment["compression_ratio"] = 0;
// segment["no_speech_prob"] = 0; segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
jres["segments"].push_back(segment); jres["segments"].push_back(segment);
} }

View File

@ -665,6 +665,8 @@ extern "C" {
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
// Get the no_speech probability for the specified segment
WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -428,6 +428,7 @@ struct whisper_segment {
int64_t t1; int64_t t1;
std::string text; std::string text;
float no_speech_prob;
std::vector<whisper_token_data> tokens; std::vector<whisper_token_data> tokens;
@ -6147,7 +6148,7 @@ int whisper_full_with_state(
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j <= i; j++) { for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
@ -6192,7 +6193,7 @@ int whisper_full_with_state(
} }
} }
result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j < (int) tokens_cur.size(); j++) { for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
return ctx->state->result_all[i_segment].tokens[i_token].p; return ctx->state->result_all[i_segment].tokens[i_token].p;
} }
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
return ctx->state->result_all[i_segment].no_speech_prob;
}
// ================================================================================================= // =================================================================================================
// //