diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 32aa3cfd..f728348e 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); - fprintf(stderr, " %-7s 0 - whisper encoder\n", ""); + fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); fprintf(stderr, "\n"); } -int whisper_bench_encoder(const whisper_params & params) { +int whisper_bench_full(const whisper_params & params) { // whisper init struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); @@ -70,12 +70,26 @@ int whisper_bench_encoder(const whisper_params & params) { return 3; } - // heat up + // heat encoder if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { fprintf(stderr, "error: failed to encode model: %d\n", ret); return 4; } + whisper_token tokens[512]; + + // prompt heat + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode model: %d\n", ret); + return 4; + } + + // text-generation heat + if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode model: %d\n", ret); + return 4; + } + whisper_reset_timings(ctx); // actual run @@ -84,6 +98,20 @@ int whisper_bench_encoder(const whisper_params & params) { return 4; } + for (int i = 0; i < 16; i++) { + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode model: %d\n", ret); + return 4; + } + } + + for (int i = 0; i < 128; i++) { + if (int ret = 
whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode model: %d\n", ret); + return 4; + } + } + whisper_print_timings(ctx); whisper_free(ctx); @@ -112,7 +140,7 @@ int main(int argc, char ** argv) { int ret = -1; switch (params.what) { - case 0: ret = whisper_bench_encoder(params); break; + case 0: ret = whisper_bench_full(params); break; case 1: ret = whisper_bench_memcpy(params.n_threads); break; case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break; default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break; diff --git a/extra/bench-all.sh b/extra/bench-all.sh index 772d55a7..98c8cfd6 100755 --- a/extra/bench-all.sh +++ b/extra/bench-all.sh @@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n" -printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n" +printf "| CPU | OS | Config | Model | Th | Enc. | Dec. | PP | Commit |\n" +printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ---- | ------ |\n" for model in "${models[@]}"; do # actual run @@ -54,14 +54,16 @@ for model in "${models[@]}"; do ret=$? 
# parse the output: - load_time=$(echo "$output" | grep "load time" | awk '{print $5}') - encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}') + encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}') + decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}') + prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}') system_info=$(echo "$output" | grep "system_info") n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}') # floor to milliseconds - load_time=${load_time%.*} - encode_time=${encode_time%.*} + #encode_time=${encode_time%.*} + #decode_time=${decode_time%.*} + #prompt_time=${prompt_time%.*} config="" @@ -84,6 +86,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n" + printf "| | | $config | $model | $n_threads | $encode_time | $decode_time | $prompt_time | $commit |\n" fi done diff --git a/whisper.cpp b/whisper.cpp index 4db95a61..b3f1dac0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -641,11 +641,13 @@ struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; + int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures @@ -2359,8 +2361,13 @@ static bool whisper_decode_internal( // wstate.get_buf_max_mem(3)/1024.0/1024.0); } - wstate.t_decode_us += ggml_time_us() - t_start_us; - wstate.n_decode++; + if (n_tokens == 1) { + wstate.t_decode_us += ggml_time_us() - t_start_us; + 
wstate.n_decode++; + } else { + wstate.t_prompt_us += ggml_time_us() - t_start_us; + wstate.n_prompt++; + } return true; } @@ -3573,12 +3580,14 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_prompt = std::max(1, ctx->state->n_prompt); log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } @@ -3588,9 +3597,11 @@ void whisper_reset_timings(struct whisper_context * ctx) { ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; ctx->state->t_decode_us = 0; + ctx->state->t_prompt_us = 0; ctx->state->n_sample = 0; ctx->state->n_encode = 0; ctx->state->n_decode = 0; + ctx->state->n_prompt = 0; } } @@ -5161,6 +5172,12 @@ int whisper_full_parallel( ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + 
ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_prompt += states[i]->n_prompt; whisper_free_state(states[i]); }