diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu
index ad36cae5..3f3e3158 100644
--- a/whisper-mel-cuda.cu
+++ b/whisper-mel-cuda.cu
@@ -8,6 +8,7 @@
 #include 
 #include 

 #include 
+#include 

 #include 
@@ -301,27 +302,23 @@ public:
             &fzero,
             mel_data, int(n_mag_frames)));

-        float * log_mels = nullptr;
-        CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
+        whisper_mel ret;
+        // Calculate semi-padded sample length to ensure compatibility
+        int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
+
+        float * log_mels = reinterpret_cast<float *>(ret.tensor->data);

         calc_log_mel(
             mel_data, int(m_n_mel * n_mag_frames),
-            m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
+            m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
             log_mels, m_stream);

-        whisper_mel ret;
-        ret.n_mel = m_n_mel;
-        ret.n_len = int(n_mag_frames);
-        // Calculate semi-padded sample length to ensure compatibility
-        ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.data.resize(m_n_mel * n_mag_frames);
-        CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));
-        CUDA_CHECK(cudaStreamSynchronize(m_stream));

         // cleanup
         CUFFT_CHECK(cufftDestroy(plan));
-        CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
         CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
         CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
         CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));

diff --git a/whisper-mel.hpp b/whisper-mel.hpp
index bc48475f..e52b804d 100644
--- a/whisper-mel.hpp
+++ b/whisper-mel.hpp
@@ -3,11 +3,23 @@
 #include 

 struct whisper_mel {
-    int n_len;
-    int n_len_org;
-    int n_mel;
+    int n_len_org = 0;

-    std::vector<float> data;
+    ggml_tensor * tensor = nullptr;
+    ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+
+    whisper_mel() = default;
+    ~whisper_mel();
+
+    whisper_mel(const whisper_mel &) = delete;
+    whisper_mel & operator=(const whisper_mel &) = delete;
+    whisper_mel(whisper_mel &&) noexcept;
+    whisper_mel & operator=(whisper_mel &&) noexcept;
+
+    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
+    void reset();
+    void take(whisper_mel & other) noexcept;
 };

 struct whisper_filters {
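With this change whisper_mel becomes a move-only RAII owner of a ggml backend buffer instead of a host-side std::vector. A minimal lifetime sketch, assuming an initialized CPU backend (the sizes and helper names here are illustrative, not part of the patch):

#include "whisper-mel.hpp"
#include "ggml-backend.h"

whisper_mel make_mel(ggml_backend_t backend) {
    whisper_mel mel;
    // allocates mel.tensor (n_len x n_mel, F32) in a buffer owned by `backend`
    mel.init(backend, /*n_len*/ 3000, /*n_len_org*/ 3000, /*n_mel*/ 80);
    return mel; // moved out; the copy constructor is deleted
}

void mel_lifetime_example() {
    ggml_backend_t backend = ggml_backend_cpu_init();
    {
        whisper_mel mel = make_mel(backend); // move construction
        // ... fill mel.tensor and hand it to the conv graph ...
    } // ~whisper_mel() calls reset(), freeing the ggml context and the buffer
    ggml_backend_free(backend);
}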
diff --git a/whisper.cpp b/whisper.cpp
index 2dd2f591..dfbcc9d3 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -821,7 +821,6 @@ struct whisper_state {
     struct ggml_tensor * embd_enc = nullptr;

     // helpers for GPU offloading
-    std::vector<float> inp_mel;
     std::vector<float> inp_mask;

     // decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {

 static struct ggml_cgraph * whisper_build_graph_conv(
         whisper_context & wctx,
-        whisper_state & wstate) {
+        whisper_state & wstate,
+        const int mel_offset) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;
@@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(

     ggml_cgraph * gf = ggml_new_graph(ctx0);

-    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
-    ggml_set_name(mel, "mel");
-    ggml_set_input(mel);
+    ggml_tensor * mel_inp = wstate.mel.tensor;
+    ggml_tensor * mel;
+    if (mel_inp) {
+        const int n_len = int(mel_inp->ne[0]);
+        const int out_s = 2 * n_ctx;
+        const int i0 = std::min(mel_offset, n_len);
+        const int i1 = std::min(mel_offset + out_s, n_len);
+        const int mel_s = i1 - i0;
+
+        assert(mel_inp->type == GGML_TYPE_F32);
+        assert(mel_inp->ne[1] == n_mels);
+
+        ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, mel_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
+
+        if (mel_s < out_s) {
+            mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
+        }
+        else {
+            mel = ggml_cont(ctx0, cur);
+        }
+    }
+    else {
+        // just create some tensor so that the graph/buffer size estimation is correct
+        mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
+    }
+    ggml_set_name(mel, "mel"); // used with external encoding

     struct ggml_tensor * cur = nullptr;
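The window arithmetic in the hunk above, as a worked example (numbers are illustrative; assuming n_ctx = 1500, so out_s = 3000):

// mel tensor with n_len = 4000 frames, mel_offset = 2000:
//   i0    = min(2000, 4000)        = 2000   // first frame of the window
//   i1    = min(2000 + 3000, 4000) = 4000   // one past the last frame
//   mel_s = i1 - i0                = 2000   // frames actually available
// mel_s < out_s, so the 2000-frame view is zero-extended on the right:
//   mel = ggml_pad(ctx0, cur, 3000 - 2000, 0, 0, 0); // -> 3000 frames
// giving the fixed 2*n_ctx window the conv layers expect; a full window
// (mel_s == out_s) only needs ggml_cont() to make the view contiguous.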
@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
     {
         auto & alloc = wstate.alloc_conv.alloc;

-        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);

         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }

-        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
-
-        // set the input
-        {
-            const auto & mel_inp = wstate.mel;
-            const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
-
-            assert(mel->type == GGML_TYPE_F32);
-            assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
-
-            wstate.inp_mel.resize(ggml_nelements(mel));
-
-            float * dst = wstate.inp_mel.data();
-            memset(dst, 0, ggml_nbytes(mel));
-
-            const int i0 = std::min(mel_offset, mel_inp.n_len);
-            const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-            for (int j = 0; j < mel_inp.n_mel; ++j) {
-                for (int i = i0; i < i1; ++i) {
-                    dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
-                }
-            }
-
-            ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
         }

-        if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
-                return false;
-            }
-        } else {
+        if (whisper_encode_external(wstate)) {
+            ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+            assert(mel->ne[1] == wctx.model.hparams.n_mels);
+            GGML_UNUSED(mel);
 #if defined(WHISPER_USE_COREML)
             whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
 #elif defined(WHISPER_USE_OPENVINO)
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {

 // Mel spectrogram

+whisper_mel::~whisper_mel() {
+    reset();
+}
+
+whisper_mel::whisper_mel(whisper_mel && other) noexcept {
+    take(other);
+}
+
+whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
+    if (this != &other) {
+        reset();
+        take(other);
+    }
+    return *this;
+}
+
+void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    this->n_len_org = n_len_org;
+    assert(!ctx);
+    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
+    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(buffer);
+    ggml_tallocr_alloc(&alloc, tensor);
+}
+
+void whisper_mel::reset() {
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+
+    n_len_org = 0;
+    tensor = nullptr;
+    ctx = nullptr;
+    buffer = nullptr;
+}
+
+void whisper_mel::take(whisper_mel & other) noexcept {
+    n_len_org = other.n_len_org;
+    tensor = other.tensor;
+    ctx = other.ctx;
+    buffer = other.buffer;
+
+    other.n_len_org = 0;
+    other.tensor = nullptr;
+    other.ctx = nullptr;
+    other.buffer = nullptr;
+}
+
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable

 whisper_span<const float> whisper_mel_calc::hann_window() {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }

-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
+namespace {
+
+struct whisper_mel_data {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+    float * data;
+};
+
+void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                               int n_samples, int n_threads,
-                                              const whisper_filters & filters, whisper_mel & mel) {
+                                              const whisper_filters & filters, whisper_mel_data & mel) {
     const auto frame_size = WHISPER_N_FFT;
     const auto frame_step = WHISPER_HOP_LENGTH;
     std::vector<float> fft_in(frame_size, 0.0);
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
         }
     }
 }
-namespace {
+
 struct mel_calc_cpu : public whisper_mel_calc {
+    ggml_backend_t m_backend;
     const whisper_filters & m_filters;
-    mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
+    mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}

     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
     whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
         // reflective pad 200 samples at the beginning of audio
         std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());

-        whisper_mel mel;
+        whisper_mel_data mel;
         mel.n_mel = m_filters.n_mel;
         // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
         // Calculate number of frames + remove the last frame
         mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
         // Calculate semi-padded sample length to ensure compatibility
         mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        mel.data.resize(mel.n_mel * mel.n_len);
+        std::vector<float> host_mel_data;
+
+        whisper_mel ret;
+        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        if (ggml_backend_buffer_is_host(ret.buffer)) {
+            mel.data = reinterpret_cast<float *>(ret.tensor->data);
+        } else {
+            host_mel_data.resize(mel.n_len * mel.n_mel);
+            mel.data = host_mel_data.data();
+        }

         {
             std::vector<std::thread> workers(n_threads - 1);
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
             mel.data[i] = (mel.data[i] + 4.0)/4.0;
         }

-        return mel;
+        if (!host_mel_data.empty()) {
+            // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
+            ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
+        }
+
+        return ret;
     }
 };
 }
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) {
         return ret;
     } else
 #endif
-    return new mel_calc_cpu(filters);
+    return new mel_calc_cpu(backend, filters);
 }

 // split text into tokens
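A sketch of how the factory and the calculate() contract fit together after this change; the PCM buffer and thread count are illustrative, and whisper_span is assumed to be the pointer-plus-length pair implied by samples.len in the CUDA path above:

#include "whisper-mel.hpp"
#include <vector>

void compute_mel_example(ggml_backend_t backend, const whisper_filters & filters,
                         const std::vector<float> & pcm) {
    whisper_mel_calc * calc = whisper_mel_calc_create(backend, filters);
    // the spectrogram now arrives as a backend tensor, not a host vector
    whisper_mel mel = calc->calculate({pcm.data(), pcm.size()}, /*n_threads*/ 4);
    // ... read mel.tensor, or let whisper_build_graph_conv view it ...
    delete calc;
}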
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     {
         bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
-                    return whisper_build_graph_conv(*ctx, *state);
+                    return whisper_build_graph_conv(*ctx, *state, 0);
                 });

         if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
         return -1;
     }

-    state->mel.n_len = n_len;
-    state->mel.n_len_org = n_len;
-    state->mel.n_mel = n_mel;
-
-    state->mel.data.resize(n_len*n_mel);
-    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+    state->mel.reset();
+    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));

     return 0;
 }
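For completeness, the public whisper_set_mel_with_state() keeps its signature; only the storage target changed. A hedged call-site sketch (dimensions are illustrative; n_mel must match the loaded model or the call returns -1):

#include <cstdio>
#include <vector>

void set_mel_example(whisper_context * ctx, whisper_state * state) {
    // 3000 frames x 80 mel bins of precomputed log-mel data
    std::vector<float> mel(3000 * 80, 0.0f);
    if (whisper_set_mel_with_state(ctx, state, mel.data(), 3000, 80) != 0) {
        fprintf(stderr, "failed to set mel spectrogram\n");
    }
    // the data now lives in state->mel.tensor, uploaded via ggml_backend_tensor_set
}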