From 2b434c449ef091db93b2b644df8b3a2912632d77 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 14:43:43 +0300 Subject: [PATCH] whisper : switch back to F32 mask (#0) --- whisper.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index f3daf5b6..bdcf3de4 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2294,8 +2294,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); - struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); - // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2379,7 +2377,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, @@ -2873,8 +2871,8 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector int i = ith; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - assert( n_fft == 1 + (frame_size / 2) ); - + assert(n_fft == 1 + (frame_size / 2)); + // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { const int offset = i * frame_step;