mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-15 23:13:10 +00:00
vad : add initial Voice Activity Detection (VAD) support (#3065)
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
* vad : add initial Voice Activity Detection (VAD) support This commit add support for Voice Activity Detection (VAD). When enabled this feature will process the audio input and detect speech segments. This information is then used to reduce the number of samples that need to be processed by whisper_full. Resolves: https://github.com/ggml-org/whisper.cpp/issues/3003 --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
e39ba750cd
commit
e41bc5c61a
20
.github/workflows/build.yml
vendored
20
.github/workflows/build.yml
vendored
@ -1253,3 +1253,23 @@ jobs:
|
||||
source venv/bin/activate
|
||||
pip install ane_transformers openai-whisper coremltools
|
||||
./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
|
||||
|
||||
vad:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
github.event.inputs.run_type == 'full-ci' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build
|
||||
shell: bash
|
||||
run: |
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
|
||||
- name: Test
|
||||
shell: bash
|
||||
run: |
|
||||
ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
|
||||
|
59
README.md
59
README.md
@ -25,6 +25,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- [Ascend NPU Support](#ascend-npu-support)
|
||||
- [Moore Threads GPU Support](#moore-threads-gpu-support)
|
||||
- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
|
||||
- [Voice Activity Detection (VAD)](#voice-activity-detection-vad)
|
||||
|
||||
Supported platforms:
|
||||
|
||||
@ -732,6 +733,64 @@ let package = Package(
|
||||
)
|
||||
```
|
||||
|
||||
### Voice Activity Detection (VAD)
|
||||
Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
|
||||
argument to `whisper-cli`. In addition to this option a VAD model is also
|
||||
required.
|
||||
|
||||
The way this works is that first the audio samples are passed through
|
||||
the VAD model which will detect speech segments. Using this information the
|
||||
only the speech segments that are detected are extracted from the original audio
|
||||
input and passed to whisper for processing. This reduces the amount of audio
|
||||
data that needs to be processed by whisper and can significantly speed up the
|
||||
transcription process.
|
||||
|
||||
The following VAD models are currently supported:
|
||||
|
||||
#### Silero-VAD
|
||||
[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
|
||||
written in Python that is fast and accurate.
|
||||
|
||||
This model can be converted to ggml using the following command:
|
||||
```console
|
||||
$ python3 -m venv venv && source venv/bin/activate
|
||||
$ (venv) pip install silero-vad
|
||||
$ (venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
|
||||
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
|
||||
```
|
||||
And it can then be used with whisper as follows:
|
||||
```console
|
||||
$ ./build/bin/whisper-cli \
|
||||
--file ./samples/jfk.wav \
|
||||
--model ./models/ggml-base.en.bin \
|
||||
--vad \
|
||||
--vad-model ./models/silero-v5.1.2-ggml.bin
|
||||
```
|
||||
|
||||
#### VAD Options
|
||||
|
||||
* --vad-threshold: Threshold probability for speech detection. A probability
|
||||
for a speech segment/frame above this threshold will be considered as speech.
|
||||
|
||||
* --vad-min-speech-duration-ms: Minimum speech duration in milliseconds. Speech
|
||||
segments shorter than this value will be discarded to filter out brief noise or
|
||||
false positives.
|
||||
|
||||
* --vad-min-silence-duration-ms: Minimum silence duration in milliseconds. Silence
|
||||
periods must be at least this long to end a speech segment. Shorter silence
|
||||
periods will be ignored and included as part of the speech.
|
||||
|
||||
* --vad-max-speech-duration-s: Maximum speech duration in seconds. Speech segments
|
||||
longer than this will be automatically split into multiple segments at silence
|
||||
points exceeding 98ms to prevent excessively long segments.
|
||||
|
||||
* --vad-speech-pad-ms: Speech padding in milliseconds. Adds this amount of padding
|
||||
before and after each detected speech segment to avoid cutting off speech edges.
|
||||
|
||||
* --vad-samples-overlap: Amount of audio to extend from each speech segment into
|
||||
the next one, in seconds (e.g., 0.10 = 100ms overlap). This ensures speech isn't
|
||||
cut off abruptly between segments when they're concatenated together.
|
||||
|
||||
## Examples
|
||||
|
||||
There are various examples of using the library for different projects in the [examples](examples) folder.
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <cfloat>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#ifndef NOMINMAX
|
||||
@ -97,6 +98,16 @@ struct whisper_params {
|
||||
std::vector<std::string> fname_out = {};
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
bool vad = false;
|
||||
std::string vad_model = "";
|
||||
float vad_threshold = 0.5f;
|
||||
int vad_min_speech_duration_ms = 250;
|
||||
int vad_min_silence_duration_ms = 100;
|
||||
float vad_max_speech_duration_s = FLT_MAX;
|
||||
int vad_speech_pad_ms = 30;
|
||||
float vad_samples_overlap = 0.1f;
|
||||
};
|
||||
|
||||
static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -185,6 +196,15 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
||||
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
|
||||
// Voice Activity Detection (VAD)
|
||||
else if (arg == "-v" || arg == "--vad") { params.vad = true; }
|
||||
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
|
||||
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-vsd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -254,6 +274,18 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
|
||||
fprintf(stderr, " -v, --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
|
||||
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
|
||||
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
|
||||
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
|
||||
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
|
||||
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
|
||||
std::string("FLT_MAX").c_str() :
|
||||
std::to_string(params.vad_max_speech_duration_s).c_str());
|
||||
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
|
||||
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -1134,6 +1166,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.suppress_nst = params.suppress_nst;
|
||||
|
||||
wparams.vad = params.vad;
|
||||
wparams.vad_model_path = params.vad_model.c_str();
|
||||
|
||||
wparams.vad_params.threshold = params.vad_threshold;
|
||||
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
|
||||
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
|
||||
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
|
||||
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
|
||||
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
|
@ -189,6 +189,15 @@ extern "C" {
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} whisper_grammar_element;
|
||||
|
||||
typedef struct whisper_vad_params {
|
||||
float threshold; // Probability threshold to consider as speech.
|
||||
int min_speech_duration_ms; // Min duration for a valid speech segment.
|
||||
int min_silence_duration_ms; // Min silence duration to consider speech as ended.
|
||||
float max_speech_duration_s; // Max duration of a speech segment before forcing a new segment.
|
||||
int speech_pad_ms; // Padding added before and after speech segments.
|
||||
float samples_overlap; // Overlap in seconds when copying audio samples from speech segment.
|
||||
} whisper_vad_params;
|
||||
|
||||
// Various functions for loading a ggml whisper model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
@ -570,11 +579,18 @@ extern "C" {
|
||||
size_t n_grammar_rules;
|
||||
size_t i_start_rule;
|
||||
float grammar_penalty;
|
||||
|
||||
// Voice Activity Detection (VAD) params
|
||||
bool vad; // Enable VAD
|
||||
const char * vad_model_path; // Path to VAD model
|
||||
|
||||
whisper_vad_params vad_params;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
|
||||
WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
|
||||
WHISPER_API struct whisper_context_params whisper_context_default_params (void);
|
||||
|
||||
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
|
||||
WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy);
|
||||
|
||||
@ -652,6 +668,53 @@ extern "C" {
|
||||
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
|
||||
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
|
||||
|
||||
//
|
||||
// Voice Activity Detection (VAD)
|
||||
//
|
||||
|
||||
struct whisper_vad_context;
|
||||
|
||||
WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
|
||||
|
||||
struct whisper_vad_context_params {
|
||||
int n_threads; // The number of threads to use for processing.
|
||||
bool use_gpu;
|
||||
int gpu_device; // CUDA device
|
||||
};
|
||||
|
||||
WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void);
|
||||
|
||||
WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(const char * path_model, struct whisper_vad_context_params params);
|
||||
WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params (struct whisper_model_loader * loader, struct whisper_vad_context_params params);
|
||||
|
||||
WHISPER_API bool whisper_vad_detect_speech(
|
||||
struct whisper_vad_context * vctx,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx);
|
||||
WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx);
|
||||
|
||||
struct whisper_vad_segments;
|
||||
|
||||
WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_probs(
|
||||
struct whisper_vad_context * vctx,
|
||||
struct whisper_vad_params params);
|
||||
|
||||
WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_samples(
|
||||
struct whisper_vad_context * vctx,
|
||||
struct whisper_vad_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
WHISPER_API int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments);
|
||||
|
||||
WHISPER_API float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment);
|
||||
WHISPER_API float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment);
|
||||
|
||||
WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments * segments);
|
||||
WHISPER_API void whisper_vad_free (struct whisper_vad_context * ctx);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Temporary helpers needed for exposing ggml interface
|
||||
|
196
models/convert-silero-vad-to-ggml.py
Normal file
196
models/convert-silero-vad-to-ggml.py
Normal file
@ -0,0 +1,196 @@
|
||||
import os
|
||||
import struct
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
from silero_vad import load_silero_vad, __version__ as silero_version
|
||||
|
||||
def convert_silero_vad(output_path, print_tensors=True):
|
||||
model = load_silero_vad()
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# Clean up state dict keys - filter out 8k model
|
||||
cleaned_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# Skip 8k model
|
||||
if "_8k" not in key:
|
||||
clean_key = key
|
||||
if not key.startswith("_model."):
|
||||
clean_key = "_model." + key
|
||||
cleaned_dict[clean_key] = value
|
||||
|
||||
base, ext = os.path.splitext(output_path)
|
||||
output_file = f"{base}-v{silero_version}-ggml{ext}"
|
||||
print(f"Saving GGML Silero-VAD model to {output_file}")
|
||||
|
||||
print("\nTensor info for debugging:")
|
||||
for key, tensor in cleaned_dict.items():
|
||||
print(f" - {key}: {tensor.shape} ({tensor.dtype})")
|
||||
print()
|
||||
|
||||
with open(output_file, "wb") as fout:
|
||||
# Write magic and version
|
||||
fout.write(struct.pack("i", 0x67676d6c))
|
||||
|
||||
model_type = "silero-16k"
|
||||
str_len = len(model_type)
|
||||
fout.write(struct.pack("i", str_len))
|
||||
fout.write(model_type.encode('utf-8'))
|
||||
|
||||
version_parts = silero_version.split('.')
|
||||
major, minor, patch = map(int, version_parts)
|
||||
print(f"Version: {major}.{minor}.{patch}")
|
||||
fout.write(struct.pack("i", major))
|
||||
fout.write(struct.pack("i", minor))
|
||||
fout.write(struct.pack("i", patch))
|
||||
|
||||
# Write model architecture parameters
|
||||
window_size = 512
|
||||
fout.write(struct.pack("i", window_size))
|
||||
context_size = 64
|
||||
fout.write(struct.pack("i", context_size))
|
||||
|
||||
n_encoder_layers = 4
|
||||
fout.write(struct.pack("i", n_encoder_layers))
|
||||
|
||||
# Write encoder dimensions
|
||||
input_channels = 129
|
||||
encoder_in_channels = [input_channels, 128, 64, 64]
|
||||
encoder_out_channels = [128, 64, 64, 128]
|
||||
kernel_size = 3
|
||||
|
||||
for i in range(n_encoder_layers):
|
||||
fout.write(struct.pack("i", encoder_in_channels[i]))
|
||||
fout.write(struct.pack("i", encoder_out_channels[i]))
|
||||
fout.write(struct.pack("i", kernel_size))
|
||||
|
||||
# Write LSTM dimensions
|
||||
lstm_input_size = 128
|
||||
lstm_hidden_size = 128
|
||||
fout.write(struct.pack("i", lstm_input_size))
|
||||
fout.write(struct.pack("i", lstm_hidden_size))
|
||||
|
||||
# Write final conv dimensions
|
||||
final_conv_in = 128
|
||||
final_conv_out = 1
|
||||
fout.write(struct.pack("i", final_conv_in))
|
||||
fout.write(struct.pack("i", final_conv_out))
|
||||
|
||||
# Define tensor keys to write
|
||||
tensor_keys = []
|
||||
|
||||
# Encoder weights
|
||||
for i in range(n_encoder_layers):
|
||||
weight_key = f"_model.encoder.{i}.reparam_conv.weight"
|
||||
bias_key = f"_model.encoder.{i}.reparam_conv.bias"
|
||||
if weight_key in cleaned_dict and bias_key in cleaned_dict:
|
||||
tensor_keys.append(weight_key)
|
||||
tensor_keys.append(bias_key)
|
||||
|
||||
# LSTM weights
|
||||
lstm_keys = [
|
||||
"_model.decoder.rnn.weight_ih",
|
||||
"_model.decoder.rnn.weight_hh",
|
||||
"_model.decoder.rnn.bias_ih",
|
||||
"_model.decoder.rnn.bias_hh"
|
||||
]
|
||||
tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])
|
||||
|
||||
# Final conv weights
|
||||
final_keys = [
|
||||
"_model.decoder.decoder.2.weight",
|
||||
"_model.decoder.decoder.2.bias"
|
||||
]
|
||||
tensor_keys.extend([k for k in final_keys if k in cleaned_dict])
|
||||
|
||||
# STFT basis - add this last
|
||||
stft_tensor = "_model.stft.forward_basis_buffer"
|
||||
tensor_keys.append(stft_tensor)
|
||||
|
||||
print(f"Writing {len(tensor_keys)} tensors:")
|
||||
for key in tensor_keys:
|
||||
if key in cleaned_dict:
|
||||
print(f" - {key}: {cleaned_dict[key].shape}")
|
||||
else:
|
||||
print(f" - {key}: MISSING")
|
||||
|
||||
# Process each tensor
|
||||
for key in tensor_keys:
|
||||
if key not in cleaned_dict:
|
||||
print(f"Warning: Missing tensor {key}, skipping")
|
||||
continue
|
||||
|
||||
tensor = cleaned_dict[key]
|
||||
|
||||
# Special handling for STFT tensor
|
||||
if key == "_model.stft.forward_basis_buffer":
|
||||
# Get the original numpy array without squeezing
|
||||
data = tensor.detach().cpu().numpy()
|
||||
# Ensure it has the expected shape
|
||||
print(f"STFT tensor original shape: {data.shape}")
|
||||
n_dims = 3
|
||||
tensor_shape = [data.shape[2], data.shape[1], data.shape[0]]
|
||||
is_conv_weight = True
|
||||
else:
|
||||
# For other tensors, we can use standard processing
|
||||
data = tensor.detach().cpu().squeeze().numpy()
|
||||
tensor_shape = list(data.shape)
|
||||
|
||||
# Ensure we have at most 4 dimensions for GGML
|
||||
n_dims = min(len(tensor_shape), 4)
|
||||
|
||||
# Reverse dimensions for GGML
|
||||
tensor_shape = tensor_shape[:n_dims]
|
||||
tensor_shape.reverse()
|
||||
|
||||
# Check if this is a convolution weight tensor
|
||||
is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)
|
||||
|
||||
# Convert to float16 for convolution weights
|
||||
if is_conv_weight:
|
||||
data = data.astype(np.float16)
|
||||
ftype = 1 # float16
|
||||
else:
|
||||
ftype = 0 # float32
|
||||
|
||||
# Debug printing of tensor info
|
||||
print(f"\nWriting tensor: {key}")
|
||||
print(f" Original shape: {tensor.shape}")
|
||||
print(f" Processed shape: {data.shape}")
|
||||
print(f" GGML dimensions: {n_dims}")
|
||||
print(f" GGML shape: {tensor_shape}")
|
||||
print(f" Type: {'float16' if ftype == 1 else 'float32'}")
|
||||
|
||||
# Convert tensor name to bytes
|
||||
name_bytes = key.encode('utf-8')
|
||||
name_length = len(name_bytes)
|
||||
|
||||
# Write tensor header
|
||||
fout.write(struct.pack("i", n_dims))
|
||||
fout.write(struct.pack("i", name_length))
|
||||
fout.write(struct.pack("i", ftype))
|
||||
|
||||
# Write tensor dimensions
|
||||
for i in range(n_dims):
|
||||
size = tensor_shape[i] if i < len(tensor_shape) else 1
|
||||
fout.write(struct.pack("i", size))
|
||||
print(f" Writing dimension {i}: {size}")
|
||||
|
||||
# Write tensor name
|
||||
fout.write(name_bytes)
|
||||
|
||||
# Write tensor data
|
||||
data.tofile(fout)
|
||||
|
||||
print(f" Wrote {data.size * (2 if ftype==1 else 4)} bytes")
|
||||
|
||||
print(f"\nDone! Model has been converted to GGML format: {output_file}")
|
||||
print(f"File size: {os.path.getsize(output_file)} bytes")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
|
||||
parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
|
||||
parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_silero_vad(args.output, args.print_tensors)
|
BIN
models/for-tests-silero-v5.1.2-ggml.bin
Normal file
BIN
models/for-tests-silero-v5.1.2-ggml.bin
Normal file
Binary file not shown.
@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
|
||||
{ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
|
||||
};
|
||||
|
||||
enum vad_tensor {
|
||||
VAD_TENSOR_STFT_BASIS,
|
||||
VAD_TENSOR_ENC_0_WEIGHT,
|
||||
VAD_TENSOR_ENC_0_BIAS,
|
||||
VAD_TENSOR_ENC_1_WEIGHT,
|
||||
VAD_TENSOR_ENC_1_BIAS,
|
||||
VAD_TENSOR_ENC_2_WEIGHT,
|
||||
VAD_TENSOR_ENC_2_BIAS,
|
||||
VAD_TENSOR_ENC_3_WEIGHT,
|
||||
VAD_TENSOR_ENC_3_BIAS,
|
||||
VAD_TENSOR_LSTM_WEIGHT_IH,
|
||||
VAD_TENSOR_LSTM_WEIGHT_HH,
|
||||
VAD_TENSOR_LSTM_BIAS_IH,
|
||||
VAD_TENSOR_LSTM_BIAS_HH,
|
||||
VAD_TENSOR_FINAL_CONV_WEIGHT,
|
||||
VAD_TENSOR_FINAL_CONV_BIAS,
|
||||
};
|
||||
|
||||
static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
|
||||
{VAD_TENSOR_STFT_BASIS, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_ENC_0_WEIGHT, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_ENC_0_BIAS, GGML_OP_ADD},
|
||||
{VAD_TENSOR_ENC_1_WEIGHT, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_ENC_1_BIAS, GGML_OP_ADD},
|
||||
{VAD_TENSOR_ENC_2_WEIGHT, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_ENC_2_BIAS, GGML_OP_ADD},
|
||||
{VAD_TENSOR_ENC_3_WEIGHT, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_ENC_3_BIAS, GGML_OP_ADD},
|
||||
|
||||
{VAD_TENSOR_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
|
||||
{VAD_TENSOR_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
|
||||
{VAD_TENSOR_LSTM_BIAS_IH, GGML_OP_ADD},
|
||||
{VAD_TENSOR_LSTM_BIAS_HH, GGML_OP_ADD},
|
||||
|
||||
{VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_IM2COL},
|
||||
{VAD_TENSOR_FINAL_CONV_BIAS, GGML_OP_ADD}
|
||||
};
|
||||
|
||||
static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
|
||||
{VAD_TENSOR_STFT_BASIS, "_model.stft.forward_basis_buffer"},
|
||||
{VAD_TENSOR_ENC_0_WEIGHT, "_model.encoder.0.reparam_conv.weight"},
|
||||
{VAD_TENSOR_ENC_0_BIAS, "_model.encoder.0.reparam_conv.bias"},
|
||||
{VAD_TENSOR_ENC_1_WEIGHT, "_model.encoder.1.reparam_conv.weight"},
|
||||
{VAD_TENSOR_ENC_1_BIAS, "_model.encoder.1.reparam_conv.bias"},
|
||||
{VAD_TENSOR_ENC_2_WEIGHT, "_model.encoder.2.reparam_conv.weight"},
|
||||
{VAD_TENSOR_ENC_2_BIAS, "_model.encoder.2.reparam_conv.bias"},
|
||||
{VAD_TENSOR_ENC_3_WEIGHT, "_model.encoder.3.reparam_conv.weight"},
|
||||
{VAD_TENSOR_ENC_3_BIAS, "_model.encoder.3.reparam_conv.bias"},
|
||||
{VAD_TENSOR_LSTM_WEIGHT_IH, "_model.decoder.rnn.weight_ih"},
|
||||
{VAD_TENSOR_LSTM_WEIGHT_HH, "_model.decoder.rnn.weight_hh"},
|
||||
{VAD_TENSOR_LSTM_BIAS_IH, "_model.decoder.rnn.bias_ih"},
|
||||
{VAD_TENSOR_LSTM_BIAS_HH, "_model.decoder.rnn.bias_hh"},
|
||||
{VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"},
|
||||
{VAD_TENSOR_FINAL_CONV_BIAS, "_model.decoder.decoder.2.bias"}
|
||||
};
|
||||
|
1392
src/whisper.cpp
1392
src/whisper.cpp
File diff suppressed because it is too large
Load Diff
@ -1,3 +1,6 @@
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
#
|
||||
# test-whisper-js
|
||||
@ -85,3 +88,18 @@ if (WHISPER_FFMPEG)
|
||||
set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
|
||||
endif()
|
||||
|
||||
# VAD test tests VAD in isolation
|
||||
set(VAD_TEST test-vad)
|
||||
add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
|
||||
target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
|
||||
target_link_libraries(${VAD_TEST} PRIVATE common)
|
||||
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
|
||||
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit")
|
||||
|
||||
# VAD test full uses whisper_full with VAD enabled
|
||||
set(VAD_TEST test-vad-full)
|
||||
add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
|
||||
target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
|
||||
target_link_libraries(${VAD_TEST} PRIVATE common)
|
||||
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
|
||||
set_tests_properties(${VAD_TARGET} PROPERTIES LABELS "base;en")
|
||||
|
54
tests/test-vad-full.cpp
Normal file
54
tests/test-vad-full.cpp
Normal file
@ -0,0 +1,54 @@
|
||||
#include "whisper.h"
|
||||
#include "common-whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cfloat>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <cassert>
|
||||
|
||||
int main() {
|
||||
std::string whisper_model_path = "../../models/ggml-base.en.bin";
|
||||
std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
|
||||
std::string sample_path = "../../samples/jfk.wav";
|
||||
|
||||
// Load the sample audio file
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
struct whisper_context * wctx = whisper_init_from_file_with_params(
|
||||
whisper_model_path.c_str(),
|
||||
cparams);
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
wparams.vad = true;
|
||||
wparams.vad_model_path = vad_model_path.c_str();
|
||||
|
||||
wparams.vad_params.threshold = 0.5f;
|
||||
wparams.vad_params.min_speech_duration_ms = 250;
|
||||
wparams.vad_params.min_silence_duration_ms = 100;
|
||||
wparams.vad_params.max_speech_duration_s = FLT_MAX;
|
||||
wparams.vad_params.speech_pad_ms = 30;
|
||||
|
||||
assert(whisper_full_parallel(wctx, wparams, pcmf32.data(), pcmf32.size(), 1) == 0);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(wctx);
|
||||
assert(n_segments == 1);
|
||||
|
||||
assert(strcmp(" And so my fellow Americans, ask not what your country can do for you,"
|
||||
" ask what you can do for your country.",
|
||||
whisper_full_get_segment_text(wctx, 0)) == 0);
|
||||
assert(whisper_full_get_segment_t0(wctx, 0) == 29);
|
||||
assert(whisper_full_get_segment_t1(wctx, 0) == 1049);
|
||||
|
||||
whisper_free(wctx);
|
||||
|
||||
return 0;
|
||||
}
|
83
tests/test-vad.cpp
Normal file
83
tests/test-vad.cpp
Normal file
@ -0,0 +1,83 @@
|
||||
#include "whisper.h"
|
||||
#include "common-whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
#include <cassert>
|
||||
|
||||
void assert_default_params(const struct whisper_vad_params & params) {
|
||||
assert(params.threshold == 0.5);
|
||||
assert(params.min_speech_duration_ms == 250);
|
||||
assert(params.min_silence_duration_ms == 100);
|
||||
assert(params.samples_overlap == 0.1f);
|
||||
}
|
||||
|
||||
void assert_default_context_params(const struct whisper_vad_context_params & params) {
|
||||
assert(params.n_threads == 4);
|
||||
assert(params.use_gpu == false);
|
||||
assert(params.gpu_device == 0);
|
||||
}
|
||||
|
||||
void test_detect_speech(
|
||||
struct whisper_vad_context * vctx,
|
||||
struct whisper_vad_params params,
|
||||
const float * pcmf32,
|
||||
int n_samples) {
|
||||
assert(whisper_vad_detect_speech(vctx, pcmf32, n_samples));
|
||||
assert(whisper_vad_n_probs(vctx) == 344);
|
||||
assert(whisper_vad_probs(vctx) != nullptr);
|
||||
}
|
||||
|
||||
struct whisper_vad_segments * test_detect_timestamps(
|
||||
struct whisper_vad_context * vctx,
|
||||
struct whisper_vad_params params) {
|
||||
struct whisper_vad_segments * timestamps = whisper_vad_segments_from_probs(vctx, params);
|
||||
assert(whisper_vad_segments_n_segments(timestamps) == 5);
|
||||
|
||||
for (int i = 0; i < whisper_vad_segments_n_segments(timestamps); ++i) {
|
||||
printf("VAD segment %d: start = %.2f, end = %.2f\n", i,
|
||||
whisper_vad_segments_get_segment_t0(timestamps, i),
|
||||
whisper_vad_segments_get_segment_t1(timestamps, i));
|
||||
}
|
||||
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
|
||||
std::string sample_path = "../../samples/jfk.wav";
|
||||
|
||||
// Load the sample audio file
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
assert(pcmf32.size() > 0);
|
||||
assert(pcmf32s.size() == 0); // no stereo vector
|
||||
|
||||
// Load the VAD model
|
||||
struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params();
|
||||
assert_default_context_params(ctx_params);
|
||||
|
||||
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
|
||||
vad_model_path.c_str(),
|
||||
ctx_params);
|
||||
assert(vctx != nullptr);
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
assert_default_params(params);
|
||||
|
||||
// Test speech probabilites
|
||||
test_detect_speech(vctx, params, pcmf32.data(), pcmf32.size());
|
||||
|
||||
// Test speech timestamps (uses speech probabilities from above)
|
||||
struct whisper_vad_segments * timestamps = test_detect_timestamps(vctx, params);
|
||||
|
||||
whisper_vad_free_segments(timestamps);
|
||||
whisper_vad_free(vctx);
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user