mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
bench : add batch size 5 bench
This commit is contained in:
parent
3ed9af34f2
commit
ae1bd69041
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# actual run
|
||||
@ -56,6 +56,7 @@ for model in "${models[@]}"; do
|
||||
# parse the output:
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
|
||||
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||
system_info=$(echo "$output" | grep "system_info")
|
||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||
@ -94,6 +95,6 @@ for model in "${models[@]}"; do
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
16
whisper.cpp
16
whisper.cpp
@ -773,13 +773,15 @@ struct whisper_state {
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_encode_us = 0;
|
||||
int64_t t_decode_us = 0;
|
||||
int64_t t_batchd_us = 0;
|
||||
int64_t t_prompt_us = 0;
|
||||
int64_t t_mel_us = 0;
|
||||
|
||||
int32_t n_sample = 0; // number of tokens sampled
|
||||
int32_t n_encode = 0; // number of encoder calls
|
||||
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
||||
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
||||
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
||||
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
|
||||
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
|
||||
int32_t n_fail_p = 0; // number of logprob threshold failures
|
||||
int32_t n_fail_h = 0; // number of entropy threshold failures
|
||||
|
||||
@ -2616,9 +2618,12 @@ static bool whisper_decode_internal(
|
||||
if (batch.n_tokens == 1) {
|
||||
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_decode++;
|
||||
} else if (batch.n_tokens < 16) {
|
||||
wstate.t_batchd_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_batchd += n_tokens;
|
||||
} else {
|
||||
wstate.t_prompt_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_prompt++;
|
||||
wstate.n_prompt += n_tokens;
|
||||
}
|
||||
|
||||
return !(abort_callback && abort_callback(abort_callback_data));
|
||||
@ -3827,6 +3832,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
|
||||
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
||||
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
||||
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
||||
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
||||
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
||||
|
||||
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
||||
@ -3834,6 +3840,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
|
||||
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
||||
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
||||
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
||||
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
||||
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
||||
}
|
||||
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
||||
@ -3850,6 +3857,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
|
||||
ctx->state->n_sample = 0;
|
||||
ctx->state->n_encode = 0;
|
||||
ctx->state->n_decode = 0;
|
||||
ctx->state->n_batchd = 0;
|
||||
ctx->state->n_prompt = 0;
|
||||
}
|
||||
}
|
||||
@ -5896,11 +5904,13 @@ int whisper_full_parallel(
|
||||
ctx->state->t_sample_us += states[i]->t_sample_us;
|
||||
ctx->state->t_encode_us += states[i]->t_encode_us;
|
||||
ctx->state->t_decode_us += states[i]->t_decode_us;
|
||||
ctx->state->t_batchd_us += states[i]->t_batchd_us;
|
||||
ctx->state->t_prompt_us += states[i]->t_prompt_us;
|
||||
|
||||
ctx->state->n_sample += states[i]->n_sample;
|
||||
ctx->state->n_encode += states[i]->n_encode;
|
||||
ctx->state->n_decode += states[i]->n_decode;
|
||||
ctx->state->n_batchd += states[i]->n_batchd;
|
||||
ctx->state->n_prompt += states[i]->n_prompt;
|
||||
|
||||
whisper_free_state(states[i]);
|
||||
|
Loading…
Reference in New Issue
Block a user