bench : add batch size 5 bench

This commit is contained in:
Georgi Gerganov 2023-11-14 22:45:08 +02:00
parent 3ed9af34f2
commit ae1bd69041
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 37 additions and 16 deletions

View File

@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
} }
// heat encoder // heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4; return 4;
} }
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
// prompt heat // prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4; return 4;
} }
// text-generation heat // text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4; return 4;
} }
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
// actual run // actual run
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4; return 4;
} }
for (int i = 0; i < 16; i++) { // text-generation
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { for (int i = 0; i < 256; i++) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4; return 4;
} }
} }
for (int i = 0; i < 256; i++) { // batched decoding
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { for (int i = 0; i < 64; i++) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4; return 4;
} }
} }

View File

@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n" printf "\n"
fi fi
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit" printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
for model in "${models[@]}"; do for model in "${models[@]}"; do
# actual run # actual run
@ -56,6 +56,7 @@ for model in "${models[@]}"; do
# parse the output: # parse the output:
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}') encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}') decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}') prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
system_info=$(echo "$output" | grep "system_info") system_info=$(echo "$output" | grep "system_info")
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}') n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@ -94,6 +95,6 @@ for model in "${models[@]}"; do
commit=$(git rev-parse --short HEAD) commit=$(git rev-parse --short HEAD)
if [ $ret -eq 0 ]; then if [ $ret -eq 0 ]; then
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit" printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
fi fi
done done

View File

@ -773,13 +773,15 @@ struct whisper_state {
int64_t t_sample_us = 0; int64_t t_sample_us = 0;
int64_t t_encode_us = 0; int64_t t_encode_us = 0;
int64_t t_decode_us = 0; int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0; int64_t t_prompt_us = 0;
int64_t t_mel_us = 0; int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures
@ -2616,9 +2618,12 @@ static bool whisper_decode_internal(
if (batch.n_tokens == 1) { if (batch.n_tokens == 1) {
wstate.t_decode_us += ggml_time_us() - t_start_us; wstate.t_decode_us += ggml_time_us() - t_start_us;
wstate.n_decode++; wstate.n_decode++;
} else if (batch.n_tokens < 16) {
wstate.t_batchd_us += ggml_time_us() - t_start_us;
wstate.n_batchd += n_tokens;
} else { } else {
wstate.t_prompt_us += ggml_time_us() - t_start_us; wstate.t_prompt_us += ggml_time_us() - t_start_us;
wstate.n_prompt++; wstate.n_prompt += n_tokens;
} }
return !(abort_callback && abort_callback(abort_callback_data)); return !(abort_callback && abort_callback(abort_callback_data));
@ -3827,6 +3832,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->state->n_decode); const int32_t n_decode = std::max(1, ctx->state->n_decode);
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
const int32_t n_prompt = std::max(1, ctx->state->n_prompt); const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
@ -3834,6 +3840,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
} }
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@ -3850,6 +3857,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
ctx->state->n_sample = 0; ctx->state->n_sample = 0;
ctx->state->n_encode = 0; ctx->state->n_encode = 0;
ctx->state->n_decode = 0; ctx->state->n_decode = 0;
ctx->state->n_batchd = 0;
ctx->state->n_prompt = 0; ctx->state->n_prompt = 0;
} }
} }
@ -5896,11 +5904,13 @@ int whisper_full_parallel(
ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us; ctx->state->t_decode_us += states[i]->t_decode_us;
ctx->state->t_batchd_us += states[i]->t_batchd_us;
ctx->state->t_prompt_us += states[i]->t_prompt_us; ctx->state->t_prompt_us += states[i]->t_prompt_us;
ctx->state->n_sample += states[i]->n_sample; ctx->state->n_sample += states[i]->n_sample;
ctx->state->n_encode += states[i]->n_encode; ctx->state->n_encode += states[i]->n_encode;
ctx->state->n_decode += states[i]->n_decode; ctx->state->n_decode += states[i]->n_decode;
ctx->state->n_batchd += states[i]->n_batchd;
ctx->state->n_prompt += states[i]->n_prompt; ctx->state->n_prompt += states[i]->n_prompt;
whisper_free_state(states[i]); whisper_free_state(states[i]);