diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 7e8d461f..ada1a312 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1253,3 +1253,23 @@ jobs:
           source venv/bin/activate
           pip install ane_transformers openai-whisper coremltools
           ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
+
+  vad:
+    if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
+            github.event.inputs.run_type == 'full-ci' }}
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Build
+        shell: bash
+        run: |
+          cmake -B build
+          cmake --build build --config Release
+
+      - name: Test
+        shell: bash
+        run: |
+          ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
diff --git a/README.md b/README.md
index 860aa608..d0ead52e 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper)
 - [Ascend NPU Support](#ascend-npu-support)
 - [Moore Threads GPU Support](#moore-threads-gpu-support)
 - [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
+- [Voice Activity Detection (VAD)](#voice-activity-detection-vad)

 Supported platforms:

@@ -732,6 +733,64 @@ let package = Package(
 )
 ```

+### Voice Activity Detection (VAD)
+Support for Voice Activity Detection (VAD) can be enabled by passing the `--vad`
+argument to `whisper-cli`. In addition to this flag, a VAD model is also
+required.
+
+The way this works is that the audio samples are first passed through the VAD
+model, which detects speech segments. Using this information, only the detected
+speech segments are extracted from the original audio input and passed to
+whisper for processing. This reduces the amount of audio data that whisper has
+to process and can significantly speed up transcription.
+
+The following VAD models are currently supported:
+
+#### Silero-VAD
+[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
+that is fast and accurate.
+
+The model can be converted to ggml format using the following commands:
+```console
+$ python3 -m venv venv && source venv/bin/activate
+(venv) $ pip install silero-vad
+(venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
+Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
+```
+It can then be used with `whisper-cli` as follows:
+```console
+$ ./build/bin/whisper-cli \
+   --file ./samples/jfk.wav \
+   --model ./models/ggml-base.en.bin \
+   --vad \
+   --vad-model ./models/silero-v5.1.2-ggml.bin
+```
+
+#### VAD Options
+
+The following options can be used to tune the VAD behavior (a combined example
+follows the list):
+
+* --vad-threshold: Threshold probability for speech detection. A probability
+for a speech segment/frame above this threshold will be considered as speech.
+
+* --vad-min-speech-duration-ms: Minimum speech duration in milliseconds. Speech
+segments shorter than this value will be discarded to filter out brief noise or
+false positives.
+
+* --vad-min-silence-duration-ms: Minimum silence duration in milliseconds. Silence
+periods must be at least this long to end a speech segment. Shorter silence
+periods will be ignored and included as part of the speech.
+
+* --vad-max-speech-duration-s: Maximum speech duration in seconds. Speech segments
+longer than this will be automatically split into multiple segments at silence
+points exceeding 98 ms, to prevent excessively long segments.
+
+* --vad-speech-pad-ms: Speech padding in milliseconds. Adds this amount of padding
+before and after each detected speech segment to avoid cutting off speech edges.
+
+* --vad-samples-overlap: Amount of audio to extend from each speech segment into
+the next one, in seconds (e.g., 0.10 = 100 ms overlap). This ensures speech isn't
+cut off abruptly between segments when they're concatenated together.
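+As an illustration, here is a hypothetical invocation that combines several of
+the options above (the tuning values are examples, not recommendations):
+```console
+$ ./build/bin/whisper-cli \
+   --file ./samples/jfk.wav \
+   --model ./models/ggml-base.en.bin \
+   --vad \
+   --vad-model ./models/silero-v5.1.2-ggml.bin \
+   --vad-threshold 0.6 \
+   --vad-min-speech-duration-ms 500 \
+   --vad-min-silence-duration-ms 300 \
+   --vad-speech-pad-ms 50
+```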
 ## Examples

 There are various examples of using the library for different projects in the
 [examples](examples) folder.
diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp
index 0cc3d38f..28dba706 100644
--- a/examples/cli/cli.cpp
+++ b/examples/cli/cli.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include <cfloat> // FLT_MAX

 #if defined(_WIN32)
 #ifndef NOMINMAX
@@ -97,6 +98,16 @@ struct whisper_params {
     std::vector<std::string> fname_out = {};

     grammar_parser::parse_state grammar_parsed;
+
+    // Voice Activity Detection (VAD) parameters
+    bool        vad                         = false;
+    std::string vad_model                   = "";
+    float       vad_threshold               = 0.5f;
+    int         vad_min_speech_duration_ms  = 250;
+    int         vad_min_silence_duration_ms = 100;
+    float       vad_max_speech_duration_s   = FLT_MAX;
+    int         vad_speech_pad_ms           = 30;
+    float       vad_samples_overlap         = 0.1f;
 };

 static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);

@@ -185,6 +196,15 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params)
     else if (                  arg == "--grammar")         { params.grammar         = ARGV_NEXT; }
     else if (                  arg == "--grammar-rule")    { params.grammar_rule    = ARGV_NEXT; }
     else if (                  arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
+    // Voice Activity Detection (VAD)
+    else if (arg == "-v"    || arg == "--vad")                         { params.vad = true; }
+    else if (arg == "-vm"   || arg == "--vad-model")                   { params.vad_model = ARGV_NEXT; }
+    else if (arg == "-vt"   || arg == "--vad-threshold")               { params.vad_threshold = std::stof(ARGV_NEXT); }
+    else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms")  { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
+    else if (arg == "-vsd"  || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(ARGV_NEXT); }
+    else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s")   { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
+    else if (arg == "-vp"   || arg == "--vad-speech-pad-ms")           { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
+    else if (arg == "-vo"   || arg == "--vad-samples-overlap")         { params.vad_samples_overlap = std::stof(ARGV_NEXT); }
     else {
         fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
         whisper_print_usage(argc, argv, params);
@@ -254,6 +274,18 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params)
     fprintf(stderr, "  --grammar GRAMMAR          [%-7s] GBNF grammar to guide decoding\n",            params.grammar.c_str());
     fprintf(stderr, "  --grammar-rule RULE        [%-7s] top-level GBNF grammar rule name\n",          params.grammar_rule.c_str());
     fprintf(stderr, "  --grammar-penalty N        [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
+    // Voice Activity Detection (VAD) parameters
+    fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
+    fprintf(stderr, "  -v,        --vad                           [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ?
"true" : "false"); + fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str()); + fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold); + fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms); + fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); + fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? + std::string("FLT_MAX").c_str() : + std::to_string(params.vad_max_speech_duration_s).c_str()); + fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms); + fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap); fprintf(stderr, "\n"); } @@ -1134,6 +1166,16 @@ int main(int argc, char ** argv) { wparams.suppress_nst = params.suppress_nst; + wparams.vad = params.vad; + wparams.vad_model_path = params.vad_model.c_str(); + + wparams.vad_params.threshold = params.vad_threshold; + wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms; + wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms; + wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s; + wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms; + wparams.vad_params.samples_overlap = params.vad_samples_overlap; + whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; const auto & grammar_parsed = params.grammar_parsed; diff --git a/include/whisper.h b/include/whisper.h index 1e137503..4aeda98f 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -189,6 +189,15 @@ extern "C" { uint32_t value; // Unicode code point or rule ID } whisper_grammar_element; + typedef struct whisper_vad_params { + float threshold; // Probability threshold to consider as speech. + int min_speech_duration_ms; // Min duration for a valid speech segment. + int min_silence_duration_ms; // Min silence duration to consider speech as ended. + float max_speech_duration_s; // Max duration of a speech segment before forcing a new segment. + int speech_pad_ms; // Padding added before and after speech segments. + float samples_overlap; // Overlap in seconds when copying audio samples from speech segment. + } whisper_vad_params; + // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. 
diff --git a/models/convert-silero-vad-to-ggml.py b/models/convert-silero-vad-to-ggml.py
new file mode 100644
index 00000000..078131c9
--- /dev/null
+++ b/models/convert-silero-vad-to-ggml.py
@@ -0,0 +1,196 @@
+import os
+import struct
+import argparse
+import torch
+import numpy as np
+from silero_vad import load_silero_vad, __version__ as silero_version
+
+def convert_silero_vad(output_path, print_tensors=True):
+    model = load_silero_vad()
+    state_dict =
model.state_dict() + + # Clean up state dict keys - filter out 8k model + cleaned_dict = {} + for key, value in state_dict.items(): + # Skip 8k model + if "_8k" not in key: + clean_key = key + if not key.startswith("_model."): + clean_key = "_model." + key + cleaned_dict[clean_key] = value + + base, ext = os.path.splitext(output_path) + output_file = f"{base}-v{silero_version}-ggml{ext}" + print(f"Saving GGML Silero-VAD model to {output_file}") + + print("\nTensor info for debugging:") + for key, tensor in cleaned_dict.items(): + print(f" - {key}: {tensor.shape} ({tensor.dtype})") + print() + + with open(output_file, "wb") as fout: + # Write magic and version + fout.write(struct.pack("i", 0x67676d6c)) + + model_type = "silero-16k" + str_len = len(model_type) + fout.write(struct.pack("i", str_len)) + fout.write(model_type.encode('utf-8')) + + version_parts = silero_version.split('.') + major, minor, patch = map(int, version_parts) + print(f"Version: {major}.{minor}.{patch}") + fout.write(struct.pack("i", major)) + fout.write(struct.pack("i", minor)) + fout.write(struct.pack("i", patch)) + + # Write model architecture parameters + window_size = 512 + fout.write(struct.pack("i", window_size)) + context_size = 64 + fout.write(struct.pack("i", context_size)) + + n_encoder_layers = 4 + fout.write(struct.pack("i", n_encoder_layers)) + + # Write encoder dimensions + input_channels = 129 + encoder_in_channels = [input_channels, 128, 64, 64] + encoder_out_channels = [128, 64, 64, 128] + kernel_size = 3 + + for i in range(n_encoder_layers): + fout.write(struct.pack("i", encoder_in_channels[i])) + fout.write(struct.pack("i", encoder_out_channels[i])) + fout.write(struct.pack("i", kernel_size)) + + # Write LSTM dimensions + lstm_input_size = 128 + lstm_hidden_size = 128 + fout.write(struct.pack("i", lstm_input_size)) + fout.write(struct.pack("i", lstm_hidden_size)) + + # Write final conv dimensions + final_conv_in = 128 + final_conv_out = 1 + fout.write(struct.pack("i", final_conv_in)) + fout.write(struct.pack("i", final_conv_out)) + + # Define tensor keys to write + tensor_keys = [] + + # Encoder weights + for i in range(n_encoder_layers): + weight_key = f"_model.encoder.{i}.reparam_conv.weight" + bias_key = f"_model.encoder.{i}.reparam_conv.bias" + if weight_key in cleaned_dict and bias_key in cleaned_dict: + tensor_keys.append(weight_key) + tensor_keys.append(bias_key) + + # LSTM weights + lstm_keys = [ + "_model.decoder.rnn.weight_ih", + "_model.decoder.rnn.weight_hh", + "_model.decoder.rnn.bias_ih", + "_model.decoder.rnn.bias_hh" + ] + tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict]) + + # Final conv weights + final_keys = [ + "_model.decoder.decoder.2.weight", + "_model.decoder.decoder.2.bias" + ] + tensor_keys.extend([k for k in final_keys if k in cleaned_dict]) + + # STFT basis - add this last + stft_tensor = "_model.stft.forward_basis_buffer" + tensor_keys.append(stft_tensor) + + print(f"Writing {len(tensor_keys)} tensors:") + for key in tensor_keys: + if key in cleaned_dict: + print(f" - {key}: {cleaned_dict[key].shape}") + else: + print(f" - {key}: MISSING") + + # Process each tensor + for key in tensor_keys: + if key not in cleaned_dict: + print(f"Warning: Missing tensor {key}, skipping") + continue + + tensor = cleaned_dict[key] + + # Special handling for STFT tensor + if key == "_model.stft.forward_basis_buffer": + # Get the original numpy array without squeezing + data = tensor.detach().cpu().numpy() + # Ensure it has the expected shape + print(f"STFT tensor original 
shape: {data.shape}") + n_dims = 3 + tensor_shape = [data.shape[2], data.shape[1], data.shape[0]] + is_conv_weight = True + else: + # For other tensors, we can use standard processing + data = tensor.detach().cpu().squeeze().numpy() + tensor_shape = list(data.shape) + + # Ensure we have at most 4 dimensions for GGML + n_dims = min(len(tensor_shape), 4) + + # Reverse dimensions for GGML + tensor_shape = tensor_shape[:n_dims] + tensor_shape.reverse() + + # Check if this is a convolution weight tensor + is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key) + + # Convert to float16 for convolution weights + if is_conv_weight: + data = data.astype(np.float16) + ftype = 1 # float16 + else: + ftype = 0 # float32 + + # Debug printing of tensor info + print(f"\nWriting tensor: {key}") + print(f" Original shape: {tensor.shape}") + print(f" Processed shape: {data.shape}") + print(f" GGML dimensions: {n_dims}") + print(f" GGML shape: {tensor_shape}") + print(f" Type: {'float16' if ftype == 1 else 'float32'}") + + # Convert tensor name to bytes + name_bytes = key.encode('utf-8') + name_length = len(name_bytes) + + # Write tensor header + fout.write(struct.pack("i", n_dims)) + fout.write(struct.pack("i", name_length)) + fout.write(struct.pack("i", ftype)) + + # Write tensor dimensions + for i in range(n_dims): + size = tensor_shape[i] if i < len(tensor_shape) else 1 + fout.write(struct.pack("i", size)) + print(f" Writing dimension {i}: {size}") + + # Write tensor name + fout.write(name_bytes) + + # Write tensor data + data.tofile(fout) + + print(f" Wrote {data.size * (2 if ftype==1 else 4)} bytes") + + print(f"\nDone! Model has been converted to GGML format: {output_file}") + print(f"File size: {os.path.getsize(output_file)} bytes") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format") + parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file") + parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True) + args = parser.parse_args() + + convert_silero_vad(args.output, args.print_tensors) diff --git a/models/for-tests-silero-v5.1.2-ggml.bin b/models/for-tests-silero-v5.1.2-ggml.bin new file mode 100644 index 00000000..c5ddfb53 Binary files /dev/null and b/models/for-tests-silero-v5.1.2-ggml.bin differ diff --git a/src/whisper-arch.h b/src/whisper-arch.h index ea2cfd60..3a65ff35 100644 --- a/src/whisper-arch.h +++ b/src/whisper-arch.h @@ -139,3 +139,59 @@ static const std::map ASR_TENSOR_INFO = { {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, {ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD}, }; + +enum vad_tensor { + VAD_TENSOR_STFT_BASIS, + VAD_TENSOR_ENC_0_WEIGHT, + VAD_TENSOR_ENC_0_BIAS, + VAD_TENSOR_ENC_1_WEIGHT, + VAD_TENSOR_ENC_1_BIAS, + VAD_TENSOR_ENC_2_WEIGHT, + VAD_TENSOR_ENC_2_BIAS, + VAD_TENSOR_ENC_3_WEIGHT, + VAD_TENSOR_ENC_3_BIAS, + VAD_TENSOR_LSTM_WEIGHT_IH, + VAD_TENSOR_LSTM_WEIGHT_HH, + VAD_TENSOR_LSTM_BIAS_IH, + VAD_TENSOR_LSTM_BIAS_HH, + VAD_TENSOR_FINAL_CONV_WEIGHT, + VAD_TENSOR_FINAL_CONV_BIAS, +}; + +static const std::map VAD_TENSOR_OPS = { + {VAD_TENSOR_STFT_BASIS, GGML_OP_IM2COL}, + {VAD_TENSOR_ENC_0_WEIGHT, GGML_OP_IM2COL}, + {VAD_TENSOR_ENC_0_BIAS, GGML_OP_ADD}, + {VAD_TENSOR_ENC_1_WEIGHT, GGML_OP_IM2COL}, + {VAD_TENSOR_ENC_1_BIAS, GGML_OP_ADD}, + {VAD_TENSOR_ENC_2_WEIGHT, GGML_OP_IM2COL}, + {VAD_TENSOR_ENC_2_BIAS, GGML_OP_ADD}, + {VAD_TENSOR_ENC_3_WEIGHT, GGML_OP_IM2COL}, + {VAD_TENSOR_ENC_3_BIAS, 
GGML_OP_ADD}, + + {VAD_TENSOR_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {VAD_TENSOR_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {VAD_TENSOR_LSTM_BIAS_IH, GGML_OP_ADD}, + {VAD_TENSOR_LSTM_BIAS_HH, GGML_OP_ADD}, + + {VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_IM2COL}, + {VAD_TENSOR_FINAL_CONV_BIAS, GGML_OP_ADD} +}; + +static const std::map VAD_TENSOR_NAMES = { + {VAD_TENSOR_STFT_BASIS, "_model.stft.forward_basis_buffer"}, + {VAD_TENSOR_ENC_0_WEIGHT, "_model.encoder.0.reparam_conv.weight"}, + {VAD_TENSOR_ENC_0_BIAS, "_model.encoder.0.reparam_conv.bias"}, + {VAD_TENSOR_ENC_1_WEIGHT, "_model.encoder.1.reparam_conv.weight"}, + {VAD_TENSOR_ENC_1_BIAS, "_model.encoder.1.reparam_conv.bias"}, + {VAD_TENSOR_ENC_2_WEIGHT, "_model.encoder.2.reparam_conv.weight"}, + {VAD_TENSOR_ENC_2_BIAS, "_model.encoder.2.reparam_conv.bias"}, + {VAD_TENSOR_ENC_3_WEIGHT, "_model.encoder.3.reparam_conv.weight"}, + {VAD_TENSOR_ENC_3_BIAS, "_model.encoder.3.reparam_conv.bias"}, + {VAD_TENSOR_LSTM_WEIGHT_IH, "_model.decoder.rnn.weight_ih"}, + {VAD_TENSOR_LSTM_WEIGHT_HH, "_model.decoder.rnn.weight_hh"}, + {VAD_TENSOR_LSTM_BIAS_IH, "_model.decoder.rnn.bias_ih"}, + {VAD_TENSOR_LSTM_BIAS_HH, "_model.decoder.rnn.bias_hh"}, + {VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"}, + {VAD_TENSOR_FINAL_CONV_BIAS, "_model.decoder.decoder.2.bias"} +}; diff --git a/src/whisper.cpp b/src/whisper.cpp index e103e29b..bba91f20 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #define _USE_MATH_DEFINES #include #include @@ -163,7 +164,6 @@ static bool ggml_graph_compute_helper( int n_threads, ggml_abort_callback abort_callback, void * abort_callback_data) { - ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); @@ -184,8 +184,8 @@ static bool ggml_graph_compute_helper( static bool ggml_graph_compute_helper( ggml_backend_sched_t sched, struct ggml_cgraph * graph, - int n_threads) { - + int n_threads, + bool sched_reset = true) { for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); ggml_backend_dev_t dev = ggml_backend_get_device(backend); @@ -197,8 +197,12 @@ static bool ggml_graph_compute_helper( } } - bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS; - ggml_backend_sched_reset(sched); + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + return t; } @@ -949,6 +953,15 @@ struct whisper_state { // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx = 0; // 0 - use default + + struct vad_segment_info { + float orig_start; + float orig_end; + float vad_start; + float vad_end; + }; + std::vector vad_segments; + bool has_vad_segments = false; }; struct whisper_context { @@ -4340,6 +4353,1118 @@ const char * whisper_print_system_info(void) { return s.c_str(); } +////////////////////////////////// +// Voice Activity Detection (VAD) +////////////////////////////////// + +struct whisper_vad_hparams { + int32_t n_encoder_layers; + int32_t * encoder_in_channels; + int32_t * encoder_out_channels; + int32_t * kernel_sizes; + int32_t lstm_input_size; + int32_t lstm_hidden_size; + int32_t final_conv_in; + int32_t final_conv_out; +}; + +struct whisper_vad_model { + std::string type; + std::string version; + whisper_vad_hparams hparams; + + struct ggml_tensor * 
stft_forward_basis; // [256, 1, 258] + + // Encoder tensors - 4 convolutional layers + struct ggml_tensor * encoder_0_weight; // [3, 129, 128] + struct ggml_tensor * encoder_0_bias; // [128] + + // Second encoder layer + struct ggml_tensor * encoder_1_weight; // [3, 128, 64] + struct ggml_tensor * encoder_1_bias; // [64] + + // Third encoder layer + struct ggml_tensor * encoder_2_weight; // [3, 64, 64] + struct ggml_tensor * encoder_2_bias; // [64] + + // Fourth encoder layer + struct ggml_tensor * encoder_3_weight; // [3, 64, 128] + struct ggml_tensor * encoder_3_bias; // [128] + + // LSTM decoder tensors + struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden + struct ggml_tensor * lstm_ih_bias; // [512] + struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden + struct ggml_tensor * lstm_hh_bias; // [512] + + // Final conv layer + struct ggml_tensor * final_conv_weight; // [128] + struct ggml_tensor * final_conv_bias; // [1] + + // ggml contexts + std::vector ctxs; + + // buffer for the model tensors + std::vector buffers; + + // tensors + int n_loaded; + std::map tensors; +}; + +struct whisper_vad_segment { + float start; // Start time in seconds + float end; // End time in seconds +}; + +struct whisper_vad_segments { + std::vector data; +}; + +struct whisper_vad_context { + int64_t t_vad_us = 0; + + int n_window; + int n_context; + int n_threads; + + std::vector backends; + ggml_backend_buffer_t buffer = nullptr; + whisper_context_params params; + std::vector ctx_buf; + whisper_sched sched; + + whisper_vad_model model; + std::string path_model; + struct ggml_tensor * h_state; + struct ggml_tensor * c_state; + std::vector probs; +}; + +struct whisper_vad_context_params whisper_vad_default_context_params(void) { + whisper_vad_context_params result = { + /*.n_thread = */ 4, + /*.use_gpu = */ false, + /*.gpu_device = */ 0, + }; + return result; +} + +struct whisper_vad_params whisper_vad_default_params(void) { + whisper_vad_params result = { + /* threshold = */ 0.5f, + /* min_speech_duration_ms = */ 250, + /* min_silence_duration_ms = */ 100, + /* max_speech_duration_s = */ FLT_MAX, + /* speech_pad_ms = */ 30, + /* samples_overlap = */ 0.1, + }; + return result; +} + +static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + int64_t n_ctx = hparams.lstm_hidden_size; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = 
ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + +static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0, + const whisper_vad_model & model, ggml_tensor * cur) { + // Apply reflective padding to the input tensor + ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64); + + struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1); + + // Calculate cutoff for real/imaginary parts + int cutoff = model.stft_forward_basis->ne[2] / 2; + + // Extract real part (first half of the STFT output). + struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0); + // Extract imaginary part (second half of the STFT output). + struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]); + + // Calculate magnitude: sqrt(real^2 + imag^2) + struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part); + struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part); + struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared); + struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares); + return magnitude; +} + +static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0, + const whisper_vad_model & model, ggml_tensor * cur) { + // First Conv1D: expands to 128 channels. + cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1); + cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1)); + cur = ggml_relu(ctx0, cur); + + // Second Conv1D: reduces to 64 channels. + cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1); + cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1)); + cur = ggml_relu(ctx0, cur); + + // Third Conv1D: maintains 64 channels + cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1); + cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1)); + cur = ggml_relu(ctx0, cur); + + // Fourth Conv1D: expands to 128 channels + cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1); + cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1)); + cur = ggml_relu(ctx0, cur); + + return cur; +} + +static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0, + const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) { + const whisper_vad_model & model = vctx.model; + const int hdim = model.hparams.lstm_hidden_size; + + struct ggml_tensor * x_t = ggml_transpose(ctx0, cur); + + // Create operations using the input-to-hidden weights. + struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t); + inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias); + + // Create operations using the hidden-to-hidden weights. 
+    struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
+    hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
+
+    // Create add operation to get preactivations for all gates.
+    struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
+
+    const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
+
+    // Input gate: sigmoid over the first hdim elements of the preactivations.
+    struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
+
+    // Forget gate: sigmoid over the second hdim elements of the preactivations.
+    struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
+
+    // Cell gate: tanh over the third hdim elements of the preactivations.
+    struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
+
+    // Output gate: sigmoid over the fourth hdim elements of the preactivations.
+    struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
+
+    // Update cell state: c_t = f_t * c_{t-1} + i_t * g_t.
+    struct ggml_tensor * c_out = ggml_add(ctx0,
+            ggml_mul(ctx0, f_t, vctx.c_state),
+            ggml_mul(ctx0, i_t, g_t));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
+
+    // Update hidden state: h_t = o_t * tanh(c_t).
+    struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
+
+    return out;
+}
+
+static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
+    const auto & model = vctx.model;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ vctx.sched.meta.size(),
+        /*.mem_buffer =*/ vctx.sched.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
+    ggml_set_name(frame, "frame");
+    ggml_set_input(frame);
+
+    struct ggml_tensor * cur = nullptr;
+    {
+        cur = whisper_vad_build_stft_layer(ctx0, model, frame);
+
+        cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
+
+        // Extract the first element of the first dimension
+        // (equivalent to pytorch's [:, :, 0])
+        cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
+
+        cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
+        cur = ggml_relu(ctx0, cur);
+        cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
+        cur = ggml_add(ctx0, cur, model.final_conv_bias);
+        cur = ggml_sigmoid(ctx0, cur);
+        ggml_set_name(cur, "prob");
+        ggml_set_output(cur);
+    }
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static bool whisper_vad_init_context(whisper_vad_context * vctx) {
+
+    auto whisper_context_params = whisper_context_default_params();
+    // TODO: GPU VAD is forced disabled until the performance is improved
+    //whisper_context_params.use_gpu    = vctx->params.use_gpu;
+    whisper_context_params.use_gpu    = false;
+    whisper_context_params.gpu_device = vctx->params.gpu_device;
+
+    vctx->backends = whisper_backend_init(whisper_context_params);
+    if (vctx->backends.empty()) {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        return false;
+    }
+
+    const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
+
+    vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ vctx->ctx_buf.size(),
+        /*.mem_buffer =*/ vctx->ctx_buf.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context * ctx = ggml_init(params);
+    if (!ctx) {
+        WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
+        return false;
+    }
+
+    // LSTM Hidden state
+    vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+    ggml_set_name(vctx->h_state, "h_state");
+
+    // LSTM Cell state
+    vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+    ggml_set_name(vctx->c_state, "c_state");
+
+    vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
+    if (!vctx->buffer) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
+        return false;
+    }
+
+    {
+        bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
+                [&]() {
+                    return whisper_vad_build_graph(*vctx);
+                });
+
+        if (!ok) {
+            WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
+            return false;
+        }
+
+        WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
+    }
+
+    return true;
+}
+
+struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+        const char * path_model,
+        struct whisper_vad_context_params params) {
+    WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
+#ifdef _MSC_VER
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring path_model_wide = converter.from_bytes(path_model);
+    auto fin = std::ifstream(path_model_wide, std::ios::binary);
+#else
+    auto fin = std::ifstream(path_model, std::ios::binary);
+#endif
+    if (!fin) {
+        WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    whisper_model_loader loader = {};
+    loader.context = &fin;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    auto ctx = whisper_vad_init_with_params(&loader, params);
+    if (!ctx) {
+        whisper_vad_free(ctx);
+        return nullptr;
+    }
+    ctx->path_model = path_model;
+    return ctx;
+}
+
+struct whisper_vad_context * whisper_vad_init_with_params(
+        struct whisper_model_loader * loader,
+        struct whisper_vad_context_params params) {
+    // Read the VAD model
+    {
+        uint32_t magic;
+        read_safe(loader, magic);
+        if (magic != GGML_FILE_MAGIC) {
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
+            return nullptr;
+        }
+    }
+
+    whisper_vad_context * vctx = new whisper_vad_context;
+    vctx->n_threads         = params.n_threads;
+    vctx->params.use_gpu    = params.use_gpu;
+    vctx->params.gpu_device = params.gpu_device;
+
+    auto & model   = vctx->model;
+    auto & hparams = model.hparams;
+
+    // load model context params
+    {
+        int32_t str_len;
+        read_safe(loader, str_len);
+        std::vector<char> buffer(str_len + 1, 0);
+        loader->read(loader->context, buffer.data(), str_len);
+        std::string model_type(buffer.data(), str_len);
+        model.type = model_type;
+        WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
+
+        int32_t major, minor, patch;
+        read_safe(loader, major);
+        read_safe(loader, minor);
+        read_safe(loader, patch);
+        std::string version_str = std::to_string(major) + "." +
+                                  std::to_string(minor) + "." +
+                                  std::to_string(patch);
+        model.version = version_str;
+        WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
+
+        read_safe(loader, vctx->n_window);
+        read_safe(loader, vctx->n_context);
+    }
+
+    // load model hyper params (hparams)
+    {
+        read_safe(loader, hparams.n_encoder_layers);
+
+        hparams.encoder_in_channels  = new int32_t[hparams.n_encoder_layers];
+        hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
+        hparams.kernel_sizes         = new int32_t[hparams.n_encoder_layers];
+
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            read_safe(loader, hparams.encoder_in_channels[i]);
+            read_safe(loader, hparams.encoder_out_channels[i]);
+            read_safe(loader, hparams.kernel_sizes[i]);
+        }
+
+        read_safe(loader, hparams.lstm_input_size);
+        read_safe(loader, hparams.lstm_hidden_size);
+        read_safe(loader, hparams.final_conv_in);
+        read_safe(loader, hparams.final_conv_out);
+
+        WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
+        }
+        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+            WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
+        }
+        WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
+        WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
+        WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
+        WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
+    }
+
+    // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
+    const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
+
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    whisper_context_params wparams = whisper_context_default_params();
+    wparams.use_gpu    = params.use_gpu;
+    wparams.gpu_device = params.gpu_device;
+    buft_list_t buft_list = make_buft_list(wparams);
+
+    auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
+        ggml_op op = VAD_TENSOR_OPS.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
+        }
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+        model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
+
+        return tensor;
+    };
+
+    // create tensors
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx = ggml_init(params);
+        const auto & hparams = model.hparams;
+
+        // STFT precomputed basis matrix
+        model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
+            ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
+
+        model.encoder_0_weight =
create_tensor(VAD_TENSOR_ENC_0_WEIGHT, + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[0], + hparams.encoder_in_channels[0], + hparams.encoder_out_channels[0] + )); + model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0])); + + model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT, + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[1], + hparams.encoder_in_channels[1], + hparams.encoder_out_channels[1] + )); + model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1])); + + model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT, + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[2], + hparams.encoder_in_channels[2], + hparams.encoder_out_channels[2] + )); + model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2])); + + model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT, + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[3], + hparams.encoder_in_channels[3], + hparams.encoder_out_channels[3] + )); + model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3])); + + // Hidden State dimension (input gate, forget gate, cell gate, output gate) + const int hstate_dim = hparams.lstm_hidden_size * 4; + + // LSTM weights - input to hidden + model.lstm_ih_weight = create_tensor( + VAD_TENSOR_LSTM_WEIGHT_IH, + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim) + ); + model.lstm_ih_bias = create_tensor( + VAD_TENSOR_LSTM_BIAS_IH, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim) + ); + + // LSTM weights - hidden to hidden + model.lstm_hh_weight = create_tensor( + VAD_TENSOR_LSTM_WEIGHT_HH, + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim) + ); + model.lstm_hh_bias = create_tensor( + VAD_TENSOR_LSTM_BIAS_HH, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim) + ); + + // Final conv layer weight + model.final_conv_weight = create_tensor( + VAD_TENSOR_FINAL_CONV_WEIGHT, + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1) + ); + model.final_conv_bias = create_tensor( + VAD_TENSOR_FINAL_CONV_BIAS, + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1) + ); + + ggml_free(ctx); + } + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + model.n_loaded = 0; + std::vector read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); + loader->read(loader->context, &tmp[0], tmp.size()); + name.assign(&tmp[0], tmp.size()); + + if 
(model.tensors.find(name) == model.tensors.end()) { + WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return nullptr; + } + + auto tensor = model.tensors[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return nullptr; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + return nullptr; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return nullptr; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (model.n_loaded == 0) { + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return nullptr; + } + + } + + if (!whisper_vad_init_context(vctx)) { + whisper_vad_free(vctx); + return nullptr; + } + + return vctx; +} + +bool whisper_vad_detect_speech( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples) { + int n_chunks = n_samples / vctx->n_window; + if (n_samples % vctx->n_window != 0) { + n_chunks += 1; // Add one more chunk for remaining samples. 
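+        // e.g. 10 s of 16 kHz audio: n_samples = 160000, n_window = 512
+        // -> 312 full chunks plus 1 partial (zero-padded) chunk = 313 chunks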
+    }
+
+    WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
+    WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
+
+    // Reset LSTM hidden/cell states
+    ggml_backend_buffer_clear(vctx->buffer, 0);
+
+    vctx->probs.resize(n_chunks);
+    WHISPER_LOG_INFO("%s: probs size: %d\n", __func__, n_chunks);
+
+    std::vector<float> window(vctx->n_window, 0.0f);
+
+    auto & sched = vctx->sched.sched;
+
+    ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
+
+    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+        WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
+    struct ggml_tensor * prob  = ggml_graph_get_tensor(gf, "prob");
+
+    // we are going to reuse the graph multiple times for each chunk
+    const int64_t t_start_vad_us = ggml_time_us();
+
+    for (int i = 0; i < n_chunks; i++) {
+        const int idx_start = i * vctx->n_window;
+        const int idx_end   = std::min(idx_start + vctx->n_window, n_samples);
+
+        const int chunk_len = idx_end - idx_start;
+
+        if (chunk_len < vctx->n_window) {
+            WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
+            std::vector<float> partial_chunk(vctx->n_window, 0.0f);
+            std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
+
+            // Copy the zero-padded chunk to the window.
+            const int samples_to_copy_max = vctx->n_window;
+            const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
+            std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
+            if (samples_to_copy_cur < samples_to_copy_max) {
+                std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
+            }
+        } else {
+            // Copy current frame samples to the window.
+            const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
+            std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
+        }
+
+        // Set the frame tensor data with the samples.
+        ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
+
+        // do not reset the scheduler - we will reuse the graph in the next chunk
+        if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
+            WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
+            break;
+        }
+
+        // Get the probability for this chunk.
+        ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
+
+        //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
+    }
+
+    vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
+    WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
+
+    ggml_backend_sched_reset(sched);
+
+    return true;
+}
+
+int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
+    return segments->data.size();
+}
+
+float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
+    return segments->data[i_segment].start;
+}
+
+float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
+    return segments->data[i_segment].end;
+}
+
+int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
+    return vctx->probs.size();
+}
+
+float * whisper_vad_probs(struct whisper_vad_context * vctx) {
+    return vctx->probs.data();
+}
+
+struct whisper_vad_segments * whisper_vad_segments_from_probs(
+        struct whisper_vad_context * vctx,
+        whisper_vad_params           params) {
+    WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
+
+    int     n_probs                 = whisper_vad_n_probs(vctx);
+    float * probs                   = whisper_vad_probs(vctx);
+    float   threshold               = params.threshold;
+    int     min_speech_duration_ms  = params.min_speech_duration_ms;
+    int     min_silence_duration_ms = params.min_silence_duration_ms;
+    float   max_speech_duration_s   = params.max_speech_duration_s;
+    int     speech_pad_ms           = params.speech_pad_ms;
+    int     n_window                = vctx->n_window;
+    int     sample_rate             = WHISPER_SAMPLE_RATE;
+    int     min_silence_samples     = sample_rate * min_silence_duration_ms / 1000;
+    int     audio_length_samples    = n_probs * n_window;
+
+    // Min number of samples to be considered valid speech.
+    int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
+    int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
+
+    // Max number of samples that a speech segment can contain before it is
+    // split into multiple segments.
+    int max_speech_samples;
+    if (max_speech_duration_s > 100000.0f) {
+        max_speech_samples = INT_MAX / 2;
+    } else {
+        int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
+        max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
+        if (max_speech_samples < 0) {
+            max_speech_samples = INT_MAX / 2;
+        }
+    }
+    // If a silence period longer than this is detected, its location (sample)
+    // is marked as a potential place where the segment could be split if
+    // max_speech_samples is reached. The value 98 ms was taken from the
+    // original silero-vad Python implementation:
+    // https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
+    int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+
+    // Calculate lower threshold for detecting end of speech segments.
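+    // e.g. with the default threshold of 0.5 the end-of-speech threshold
+    // becomes 0.5 - 0.15 = 0.35: a segment starts once a probability rises
+    // above 0.5 but only ends after probabilities fall below 0.35 (hysteresis).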
+ float neg_threshold = threshold - 0.15f; + if (neg_threshold < 0.01f) { + neg_threshold = 0.01f; + } + + struct speech_segment_t { + int start; + int end; + }; + + std::vector speeches; + speeches.reserve(256); + + bool is_speech_segment = false; + int temp_end = 0; + int prev_end = 0; + int next_start = 0; + int curr_speech_start = 0; + bool has_curr_speech = false; + + for (int i = 0; i < n_probs; i++) { + float curr_prob = probs[i]; + int curr_sample = n_window * i; + + // Reset temp_end when we get back to speech + if ((curr_prob >= threshold) && temp_end) { + temp_end = 0; + if (next_start < prev_end) { + next_start = curr_sample; + } + } + + // Start a new speech segment when probability exceeds threshold and not already in speech + if ((curr_prob >= threshold) && !is_speech_segment) { + is_speech_segment = true; + curr_speech_start = curr_sample; + has_curr_speech = true; + continue; + } + + // Handle maximum speech duration + if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) { + if (prev_end) { + speeches.push_back({ curr_speech_start, prev_end }); + has_curr_speech = true; + + if (next_start < prev_end) { // Previously reached silence and is still not speech + is_speech_segment = false; + has_curr_speech = false; + } else { + curr_speech_start = next_start; + } + prev_end = next_start = temp_end = 0; + } else { + speeches.push_back({ curr_speech_start, curr_sample }); + + prev_end = next_start = temp_end = 0; + is_speech_segment = false; + has_curr_speech = false; + continue; + } + } + + // Handle silence after speech + if ((curr_prob < neg_threshold) && is_speech_segment) { + if (!temp_end) { + temp_end = curr_sample; + } + + // Track potential segment ends for max_speech handling + if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) { + prev_end = temp_end; + } + + // Check if silence is long enough to end the segment + if ((curr_sample - temp_end) < min_silence_samples) { + continue; + } else { + // End the segment if it's long enough + if ((temp_end - curr_speech_start) > min_speech_samples) { + speeches.push_back({ curr_speech_start, temp_end }); + } + + prev_end = next_start = temp_end = 0; + is_speech_segment = false; + has_curr_speech = false; + continue; + } + } + } + + // Handle the case if we're still in a speech segment at the end + if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) { + speeches.push_back({ curr_speech_start, audio_length_samples }); + } + + // Merge adjacent segments with small gaps in between (post-processing) + if (speeches.size() > 1) { + int merged_count = 0; + for (int i = 0; i < (int) speeches.size() - 1; i++) { + // Define maximum gap allowed for merging (e.g., 200ms converted to samples) + int max_merge_gap_samples = sample_rate * 200 / 1000; + + // If the gap between this segment and the next is small enough + if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) { + // Merge by extending current segment to the end of next segment + speeches[i].end = speeches[i+1].end; + speeches.erase(speeches.begin() + i + 1); + + i--; + merged_count++; + } + } + WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n", + __func__, merged_count, (int) speeches.size()); + } + + // Double-check for minimum speech duration + for (int i = 0; i < (int) speeches.size(); i++) { + if (speeches[i].end - speeches[i].start < min_speech_samples) { + WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n", + __func__, i, speeches[i].end - 
speeches[i].start); + + speeches.erase(speeches.begin() + i); + i--; + } + } + + WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size()); + + // Allocate final segments + std::vector segments; + if (speeches.size() > 0) { + try { + segments.resize(speeches.size()); + } catch (const std::bad_alloc &) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__); + return nullptr; + } + } + + // Apply padding to segments and copy to final segments + for (int i = 0; i < (int) speeches.size(); i++) { + // Apply padding to the start of the first segment + if (i == 0) { + speeches[i].start = + (speeches[i].start > speech_pad_samples) ? + (speeches[i].start - speech_pad_samples) : 0; + } + + // Handle spacing between segments + if (i < (int) speeches.size() - 1) { + int silence_duration = speeches[i+1].start - speeches[i].end; + + if (silence_duration < 2 * speech_pad_samples) { + // If segments are close, split the difference + speeches[i].end += silence_duration / 2; + speeches[i+1].start = + (speeches[i+1].start > silence_duration / 2) ? + (speeches[i+1].start - silence_duration / 2) : 0; + } else { + // Otherwise, apply full padding to both + speeches[i].end = + (speeches[i].end + speech_pad_samples < audio_length_samples) ? + (speeches[i].end + speech_pad_samples) : audio_length_samples; + speeches[i+1].start = + (speeches[i+1].start > speech_pad_samples) ? + (speeches[i+1].start - speech_pad_samples) : 0; + } + } else { + // Apply padding to the end of the last segment + speeches[i].end = + (speeches[i].end + speech_pad_samples < audio_length_samples) ? + (speeches[i].end + speech_pad_samples) : audio_length_samples; + } + + // Convert from samples to seconds and copy to final segments + segments[i].start = (float)speeches[i].start / sample_rate; + segments[i].end = (float)speeches[i].end / sample_rate; + + WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n", + __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start); + } + + whisper_vad_segments * vad_segments = new whisper_vad_segments; + if (vad_segments == NULL) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__); + return nullptr; + } + + vad_segments->data = std::move(segments); + + return vad_segments; +} + +struct whisper_vad_segments * whisper_vad_segments_from_samples( + whisper_vad_context * vctx, + whisper_vad_params params, + const float * samples, + int n_samples) { + WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples); + if (!whisper_vad_detect_speech(vctx, samples, n_samples)) { + WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__); + return nullptr; + } + return whisper_vad_segments_from_probs(vctx, params); +} + +void whisper_vad_free(whisper_vad_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + ggml_backend_sched_free(ctx->sched.sched); + + for (auto & backend : ctx->backends) { + ggml_backend_free(backend); + } + + + delete ctx; + } +} + +void whisper_vad_free_segments(whisper_vad_segments * segments) { + if (segments) { + delete segments; + } +} + ////////////////////////////////// // Grammar - ported from llama.cpp ////////////////////////////////// @@ -4856,6 +5981,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str 
 
 //////////////////////////////////
 // Grammar - ported from llama.cpp
 //////////////////////////////////
 
@@ -4856,6 +5981,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
         /*.n_grammar_rules =*/ 0,
         /*.i_start_rule    =*/ 0,
         /*.grammar_penalty =*/ 100.0f,
+
+        /*.vad            =*/ false,
+        /*.vad_model_path =*/ nullptr,
+
+        /*.vad_params     =*/ whisper_vad_default_params(),
     };
 
     switch (strategy) {
@@ -5472,6 +6602,117 @@ static void whisper_sequence_score(
     }
 }
 
+static bool whisper_vad(
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples,
+        std::vector<float> & filtered_samples,
+        int & filtered_n_samples) {
+    WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+    filtered_n_samples = 0;
+
+    struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
+    struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
+    if (vctx == nullptr) {
+        WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
+        return false;
+    }
+
+    const whisper_vad_params & vad_params = params.vad_params;
+
+    whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
+    if (vad_segments == nullptr) {
+        WHISPER_LOG_ERROR("%s: failed to detect speech segments\n", __func__);
+        whisper_vad_free(vctx);
+        return false;
+    }
+
+    if (vad_segments->data.size() > 0) {
+        state->has_vad_segments = true;
+        ctx->state->vad_segments.clear();
+        ctx->state->vad_segments.reserve(vad_segments->data.size());
+
+        WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int) vad_segments->data.size());
+        float overlap_seconds = vad_params.samples_overlap;
+        int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
+
+        for (int i = 0; i < (int) vad_segments->data.size(); i++) {
+            int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+            int segment_end_samples   = vad_segments->data[i].end   * WHISPER_SAMPLE_RATE;
+
+            // All segments except the last one are extended by the configured overlap
+            if (i < (int) vad_segments->data.size() - 1) {
+                segment_end_samples += overlap_samples;
+            }
+            // Clamp with the same bound as the copy loop below so the size
+            // computed here matches the number of samples actually copied
+            segment_end_samples = std::min(segment_end_samples, n_samples);
+            filtered_n_samples += (segment_end_samples - segment_start_samples);
+
+            WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
+                    __func__, i, vad_segments->data[i].start,
+                    vad_segments->data[i].end + (i < (int) vad_segments->data.size() - 1 ? overlap_seconds : 0),
+                    (vad_segments->data[i].end - vad_segments->data[i].start) +
+                    (i < (int) vad_segments->data.size() - 1 ? overlap_seconds : 0));
+        }
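+
+        // The segments kept above are now concatenated into a single buffer,
+        // separated by 0.1 s of artificial silence between consecutive segments;
+        // the total size is computed first so the buffer can be allocated once.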
+        int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
+        int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
+        int total_samples_needed = filtered_n_samples + total_silence_samples;
+
+        WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
+                __func__, (float) filtered_n_samples / WHISPER_SAMPLE_RATE);
+
+        try {
+            filtered_samples.resize(total_samples_needed);
+        } catch (const std::bad_alloc & /* e */) {
+            WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
+            whisper_vad_free_segments(vad_segments);
+            whisper_vad_free(vctx);
+            return false;
+        }
+
+        int offset = 0;
+        for (int i = 0; i < (int) vad_segments->data.size(); i++) {
+            int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+            int segment_end_samples   = vad_segments->data[i].end   * WHISPER_SAMPLE_RATE;
+
+            if (i < (int) vad_segments->data.size() - 1) {
+                segment_end_samples += overlap_samples;
+            }
+
+            segment_start_samples = std::min(segment_start_samples, n_samples - 1);
+            segment_end_samples   = std::min(segment_end_samples, n_samples);
+            int segment_length = segment_end_samples - segment_start_samples;
+
+            if (segment_length > 0) {
+                whisper_state::vad_segment_info segment;
+
+                // Remember both the position in the original audio and in the
+                // filtered audio so that timestamps can be mapped back later
+                segment.orig_start = vad_segments->data[i].start;
+                segment.orig_end   = vad_segments->data[i].end;
+
+                segment.vad_start = offset / (float) WHISPER_SAMPLE_RATE;
+                segment.vad_end   = (offset + segment_length) / (float) WHISPER_SAMPLE_RATE;
+
+                WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
+                        __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
+                ctx->state->vad_segments.push_back(segment);
+
+                // Copy this speech segment
+                memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
+                offset += segment_length;
+
+                // Add silence after this segment (except after the last segment)
+                if (i < (int) vad_segments->data.size() - 1) {
+                    // Fill with zeros (silence)
+                    memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
+                    offset += silence_samples;
+                }
+            }
+        }
+
+        filtered_n_samples = offset;
+        WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
+                __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float) filtered_n_samples / n_samples));
+    }
+
+    whisper_vad_free_segments(vad_segments);
+    whisper_vad_free(vctx);
+
+    return true;
+}
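+
+// After whisper_vad has run, whisper_full_with_state (below) feeds the filtered
+// samples through the regular pipeline. The segment timestamps it produces are
+// relative to the filtered audio and are mapped back to the original audio by
+// whisper_full_get_segment_t0/t1 further down in this file.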
 
 int whisper_full_with_state(
         struct whisper_context * ctx,
         struct whisper_state * state,
@@ -5483,9 +6724,24 @@ int whisper_full_with_state(
 
     result_all.clear();
 
-    if (n_samples > 0) {
+    const float * process_samples = samples;
+    int n_process_samples = n_samples;
+    std::vector<float> vad_samples;
+
+    if (params.vad) {
+        WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+        int vad_n_samples;
+        if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
+            WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+            return -1;
+        }
+        process_samples   = vad_samples.data();
+        n_process_samples = vad_n_samples;
+    }
+
+    if (n_process_samples > 0) {
         // compute log mel spectrogram
-        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
@@ -6530,19 +7786,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
 }
 
 int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t0;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t0;
+    }
+
+    // Get the start timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case, so we need to map these timestamps
+    // back to the original audio.
+    float t0 = state->result_all[i_segment].t0 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs to.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceeds a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimize for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto & segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t) (orig_t0 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t0 > curr.vad_end && t0 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t) (orig_t0 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment.
+    if (t0 > state->vad_segments.back().vad_end) {
+        // For timestamps after the last segment, add the extra time to the end of the last segment
+        const auto & last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t0 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t0 = last.orig_end + extra_time;
+        return (int64_t) (orig_t0 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
+    // Fall back to the unmapped timestamp. Note that t0 is in seconds here, so
+    // convert back to the 10 ms units used by whisper timestamps.
+    return (int64_t) (t0 * 100);
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t0;
+    return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
 }
 
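+// Worked example of the mapping above, with illustrative numbers: if VAD kept a
+// single segment covering 2.00 - 5.00 s of the original audio, it occupies
+// 0.00 - 3.00 s of the filtered audio. A timestamp of t0 = 1.50 s in the
+// filtered audio gives proportion = (1.50 - 0.00) / (3.00 - 0.00) = 0.5, which
+// maps to 2.00 + 0.5 * (5.00 - 2.00) = 3.50 s, returned as 350 (10 ms units).
+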
 int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t1;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t1;
+    }
+
+    // Get the end timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case, so we need to map these timestamps
+    // back to the original audio.
+    float t1 = state->result_all[i_segment].t1 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs to.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceeds a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimize for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto & segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
+            // Calculate the proportion through the filtered segment.
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t) (orig_t1 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t1 > curr.vad_end && t1 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            // Map to the corresponding position in the original gap
+            float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t) (orig_t1 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment.
+    if (t1 > state->vad_segments.back().vad_end) {
+        // For timestamps after the last segment, add the extra time to the end of the last segment
+        const auto & last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t1 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t1 = last.orig_end + extra_time;
+        return (int64_t) (orig_t1 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
+    // Fall back to the unmapped timestamp (converted back to 10 ms units)
+    return (int64_t) (t1 * 100);
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t1;
+    return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
 }
 
 bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 7cdfed82..efa1bbe3 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,3 +1,6 @@
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
 if (EMSCRIPTEN)
     #
     # test-whisper-js
@@ -85,3 +88,18 @@ if (WHISPER_FFMPEG)
     set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
 endif()
 
+# The VAD test exercises VAD in isolation
+set(VAD_TEST test-vad)
+add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
+target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
+target_link_libraries(${VAD_TEST} PRIVATE common)
+add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
+set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit")
+
+# The full VAD test uses whisper_full with VAD enabled
+set(VAD_TEST test-vad-full)
+add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
+target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
+target_link_libraries(${VAD_TEST} PRIVATE common)
+add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
+set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en")
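+
+# Example invocation (a sketch; assumes the models and samples referenced by
+# the tests are present relative to the build tree):
+#
+#   ctest --test-dir build -L unit --output-on-failure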
diff --git a/tests/test-vad-full.cpp b/tests/test-vad-full.cpp
new file mode 100644
index 00000000..9eac11ed
--- /dev/null
+++ b/tests/test-vad-full.cpp
@@ -0,0 +1,54 @@
+#include "whisper.h"
+#include "common-whisper.h"
+
+#include <cfloat>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include <cassert>
+
+int main() {
+    std::string whisper_model_path = "../../models/ggml-base.en.bin";
+    std::string vad_model_path     = "../../models/for-tests-silero-v5.1.2-ggml.bin";
+    std::string sample_path        = "../../samples/jfk.wav";
+
+    // Load the sample audio file
+    std::vector<float> pcmf32;
+    std::vector<std::vector<float>> pcmf32s;
+    assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
+
+    struct whisper_context_params cparams = whisper_context_default_params();
+    struct whisper_context * wctx = whisper_init_from_file_with_params(
+            whisper_model_path.c_str(),
+            cparams);
+
+    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
+    wparams.vad            = true;
+    wparams.vad_model_path = vad_model_path.c_str();
+
+    wparams.vad_params.threshold               = 0.5f;
+    wparams.vad_params.min_speech_duration_ms  = 250;
+    wparams.vad_params.min_silence_duration_ms = 100;
+    wparams.vad_params.max_speech_duration_s   = FLT_MAX;
+    wparams.vad_params.speech_pad_ms           = 30;
+
+    assert(whisper_full_parallel(wctx, wparams, pcmf32.data(), pcmf32.size(), 1) == 0);
+
+    const int n_segments = whisper_full_n_segments(wctx);
+    assert(n_segments == 1);
+
+    assert(strcmp(" And so my fellow Americans, ask not what your country can do for you,"
+                  " ask what you can do for your country.",
+                  whisper_full_get_segment_text(wctx, 0)) == 0);
+    // Timestamps are in 10 ms units, i.e. 0.29 s - 10.49 s in the original audio
+    assert(whisper_full_get_segment_t0(wctx, 0) == 29);
+    assert(whisper_full_get_segment_t1(wctx, 0) == 1049);
+
+    whisper_free(wctx);
+
+    return 0;
+}
diff --git a/tests/test-vad.cpp b/tests/test-vad.cpp
new file mode 100644
index 00000000..e6886e31
--- /dev/null
+++ b/tests/test-vad.cpp
@@ -0,0 +1,83 @@
+#include "whisper.h"
+#include "common-whisper.h"
+
+#include <cstdio>
+#include <string>
+#include <vector>
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+#include <cassert>
+
+void assert_default_params(const struct whisper_vad_params & params) {
+    assert(params.threshold == 0.5f);
+    assert(params.min_speech_duration_ms == 250);
+    assert(params.min_silence_duration_ms == 100);
+    assert(params.samples_overlap == 0.1f);
+}
+
+void assert_default_context_params(const struct whisper_vad_context_params & params) {
+    assert(params.n_threads == 4);
+    assert(params.use_gpu == false);
+    assert(params.gpu_device == 0);
+}
+
+void test_detect_speech(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params    params,
+        const float                * pcmf32,
+        int                          n_samples) {
+    assert(whisper_vad_detect_speech(vctx, pcmf32, n_samples));
+    // One probability is produced per VAD window; the ~11 s jfk.wav sample
+    // yields 344 of them
+    assert(whisper_vad_n_probs(vctx) == 344);
+    assert(whisper_vad_probs(vctx) != nullptr);
+}
+
+struct whisper_vad_segments * test_detect_timestamps(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params    params) {
+    struct whisper_vad_segments * timestamps = whisper_vad_segments_from_probs(vctx, params);
+    assert(whisper_vad_segments_n_segments(timestamps) == 5);
+
+    for (int i = 0; i < whisper_vad_segments_n_segments(timestamps); ++i) {
+        printf("VAD segment %d: start = %.2f, end = %.2f\n", i,
+                whisper_vad_segments_get_segment_t0(timestamps, i),
+                whisper_vad_segments_get_segment_t1(timestamps, i));
+    }
+
+    return timestamps;
+}
+
+int main() {
+    std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
+    std::string sample_path    = "../../samples/jfk.wav";
+
+    // Load the sample audio file
+    std::vector<float> pcmf32;
+    std::vector<std::vector<float>> pcmf32s;
+    assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
+    assert(pcmf32.size() > 0);
+    assert(pcmf32s.size() == 0); // no stereo vector
+
+    // Load the VAD model
+    struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params();
+    assert_default_context_params(ctx_params);
+
+    struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
+            vad_model_path.c_str(),
+            ctx_params);
+    assert(vctx != nullptr);
+
+    struct whisper_vad_params params = whisper_vad_default_params();
+    assert_default_params(params);
+
+    // Test speech probabilities
+    test_detect_speech(vctx, params, pcmf32.data(), pcmf32.size());
+
+    // Test speech timestamps (uses speech probabilities from above)
+    struct whisper_vad_segments * timestamps = test_detect_timestamps(vctx, params);
+
+    whisper_vad_free_segments(timestamps);
+    whisper_vad_free(vctx);
+
+    return 0;
+}