whisper : add batched decoding (#1486)

* whisper : add whisper_batch

* whisper : move kv_self to whisper_state

* whisper : full batched decoding support

* whisper : fix memory leak in whisper_batch

* whisper : fix mem leak again + remove oboslete function

* whisper : clear kv cache when using whisper_decode API

* whisper : speed-up sampling

* whisper : fix decoders initializer

* bench : add batch size 5 bench

* whisper : add comment about the KV cache size

* whisper : add check for max number of decoders

* whisper : avoid starting sampling threads with bs=1

* whisper : enable beam-search by default

* cuda : sync llama.cpp fixes
This commit is contained in:
Georgi Gerganov
2023-11-15 16:12:52 +02:00
committed by GitHub
parent d4231649e6
commit b6c5f49b78
7 changed files with 836 additions and 572 deletions

View File

@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
// actual run
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}