mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-17 22:38:07 +00:00
whisper : add batched decoding (#1486)
* whisper : add whisper_batch * whisper : move kv_self to whisper_state * whisper : full batched decoding support * whisper : fix memory leak in whisper_batch * whisper : fix mem leak again + remove oboslete function * whisper : clear kv cache when using whisper_decode API * whisper : speed-up sampling * whisper : fix decoders initializer * bench : add batch size 5 bench * whisper : add comment about the KV cache size * whisper : add check for max number of decoders * whisper : avoid starting sampling threads with bs=1 * whisper : enable beam-search by default * cuda : sync llama.cpp fixes
This commit is contained in:
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -62,8 +62,8 @@ struct whisper_params {
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
|
Reference in New Issue
Block a user