diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu index 9a6f1093..cc44556f 100644 --- a/whisper-mel-cuda.cu +++ b/whisper-mel-cuda.cu @@ -145,17 +145,6 @@ void calc_magnitudes( constexpr auto LOG_MEL_PREFIX_SIZE = 256; -size_t get_log_mel_temp_storage_size() { - constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT; - constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; - constexpr auto maxMels = 160; - - size_t nbytes = 0; - float * temp = nullptr; - cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, maxFrames * maxMels); - return nbytes + LOG_MEL_PREFIX_SIZE; -} - void calc_log_mel( const float * mel_data, int n_mel, @@ -186,11 +175,14 @@ class mel_calc_cuda : public whisper_mel_calc { float * m_hann_window = nullptr; + float * m_filters = nullptr; + + // max samples for which we have allocated memory for the temp working areas below (cufft, log_mel) + int m_n_max_samples = 0; + size_t m_cufft_workspace_size = 0; void * m_cufft_workspace = nullptr; - float * m_filters = nullptr; - size_t m_log_mel_temp_storage_size = 0; void * m_log_mel_temp_storage = nullptr; public: @@ -215,14 +207,6 @@ public: CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream)); } - // create working area - { - constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT; - constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; - CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, maxFrames, &m_cufft_workspace_size)); - CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream)); - } - // fill filters { auto& f = filters.data; @@ -230,10 +214,8 @@ public: CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream)); } - { - m_log_mel_temp_storage_size = get_log_mel_temp_storage_size(); - CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, 
m_log_mel_temp_storage_size, m_stream)); - } + // preallocate working areas enough for the most common cases (<= 30s) + ensure_working_areas(WHISPER_N_SAMPLES); } ~mel_calc_cuda() { @@ -245,7 +227,49 @@ public: CUDA_CHECK(cudaFree(m_log_mel_temp_storage)); } - virtual whisper_mel calculate(whisper_span samples, int /*n_threads*/) const override { + void ensure_working_areas(int n_samples) { + if (n_samples <= m_n_max_samples) { + return; + } + + const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT; + const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + + // cufft workspace + { + if (m_cufft_workspace) { + CUDA_CHECK(cudaFree(m_cufft_workspace)); + m_cufft_workspace_size = 0; + m_cufft_workspace = nullptr; + } + CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size)); + CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream)); + } + + // device reduce working area + { + if (m_log_mel_temp_storage) { + CUDA_CHECK(cudaFree(m_log_mel_temp_storage)); + m_log_mel_temp_storage_size = 0; + m_log_mel_temp_storage = nullptr; + } + + const auto max_mels = 160; + + size_t nbytes = 0; + float* temp = nullptr; + cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels); + m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE; + + CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream)); + } + + m_n_max_samples = n_samples; + } + + virtual whisper_mel calculate(whisper_span samples, int /*n_threads*/) override { + ensure_working_areas(samples.len); + const size_t mirror_pad = WHISPER_N_FFT / 2; const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT; diff --git a/whisper-mel.hpp b/whisper-mel.hpp index 1a54a23c..f4210b41 100644 --- a/whisper-mel.hpp +++ b/whisper-mel.hpp @@ -29,6 +29,6 @@ struct whisper_span { struct whisper_mel_calc { virtual ~whisper_mel_calc(); - virtual whisper_mel 
calculate(whisper_span samples, int n_threads) const = 0; + virtual whisper_mel calculate(whisper_span samples, int n_threads) = 0; static whisper_span hann_window(); }; diff --git a/whisper.cpp b/whisper.cpp index 58b4a65e..457fef9f 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -802,6 +802,7 @@ struct whisper_state { whisper_mel mel; whisper_mel_calc * mel_calc = nullptr; + whisper_mel_calc * mel_calc_fallback = nullptr; whisper_batch batch; @@ -3079,7 +3080,7 @@ struct mel_calc_cpu : public whisper_mel_calc { mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {} // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 - whisper_mel calculate(whisper_span ssamples, int n_threads) const override { + whisper_mel calculate(whisper_span ssamples, int n_threads) override { // Hann window const float * hann = global_cache.hann_window; @@ -3721,6 +3722,8 @@ void whisper_free_state(struct whisper_state * state) { delete state->mel_calc; state->mel_calc = nullptr; + delete state->mel_calc_fallback; + state->mel_calc_fallback = nullptr; #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { @@ -3778,11 +3781,24 @@ void whisper_free_params(struct whisper_full_params * params) { } } -int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { +int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { const int64_t t_start_us = ggml_time_us(); whisper_mel_free(state->mel); - state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads); + if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE) { + // calculate mel spectrogram for lengths up to 5 minutes on the most optimal mel calculator + state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads); + } else { + // calculate mel spectrogram for 
longer audios on the CPU + // 1. gpu calculations may use hundreds of megabytes of memory for longer audios so we're being conservative + // with our gpu demands + // 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation + // taking longer is not a major concern + if (!state->mel_calc_fallback) { + state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters); + } + state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads); + } state->t_mel_us += ggml_time_us() - t_start_us;