Merge branch 'master' into avx512

Georgi Gerganov 2022-11-06 09:09:50 +02:00
commit 5d895d60b6
18 changed files with 1519 additions and 251 deletions

CMakeLists.txt

@@ -41,8 +41,13 @@ option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STAND
option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_PERF "whisper: enable perf timings" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF)
# sanitizers
@@ -86,6 +91,18 @@ if (APPLE AND NOT WHISPER_NO_ACCELERATE)
endif()
endif()
if (WHISPER_SUPPORT_OPENBLAS)
find_library(OPENBLAS_LIB openblas)
if (OPENBLAS_LIB)
message(STATUS "OpenBLAS found")
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${OPENBLAS_LIB})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
else()
message(WARNING "OpenBLAS not found")
endif()
endif()
# compiler flags
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@@ -134,6 +151,10 @@ else()
endif()
endif()
if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif()
#
# whisper - this is the main library of the project
#
@@ -151,6 +172,8 @@ target_include_directories(${TARGET} PUBLIC
if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -D_CRT_SECURE_NO_WARNINGS)
else()
target_link_libraries(${TARGET} PRIVATE m ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
endif()
@@ -163,10 +186,6 @@ if (BUILD_SHARED_LIBS)
target_compile_definitions(${TARGET} PUBLIC
WHISPER_SHARED
)
if (MSVC)
target_compile_definitions(${TARGET} PUBLIC __AVX2__ _CRT_SECURE_NO_WARNINGS)
endif()
endif()
target_compile_definitions(${TARGET} PUBLIC

Makefile

@@ -1,6 +1,14 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
@@ -8,8 +16,8 @@ ifeq ($(UNAME_S),Darwin)
ifneq ($(UNAME_P),arm)
SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
ifeq ($(SYSCTL_M),1)
UNAME_P := arm
UNAME_M := arm64
# UNAME_P := arm
# UNAME_M := arm64
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
endif
endif
@@ -51,7 +59,7 @@ endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifneq ($(filter arm%,$(UNAME_M)),)
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE
@@ -82,13 +90,13 @@ main: examples/main/main.cpp ggml.o whisper.o
./main -h
ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
libwhisper.a: ggml.o whisper.o
ar rcs libwhisper.a ggml.o whisper.o
$(AR) rcs libwhisper.a ggml.o whisper.o
clean:
rm -f *.o main stream bench libwhisper.a

README.md

@@ -26,14 +26,41 @@ Supported platforms:
The entire implementation of the model is contained in two source files:
- [ggml.h](ggml.h) / [ggml.c](ggml.c)
- [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
Having such a lightweight implementation of the model makes it easy to integrate into different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device:
https://user-images.githubusercontent.com/1991296/197385372-962a6dea-bca1-4d50-bf96-1d8c27b98c81.mp4
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm NEON SIMD
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
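For illustration, here is a minimal sketch of what the BLAS path boils down to for a single 2D slice - one `cblas_sgemm` call computing `dst = src1 * src0^T` (this mirrors the `ggml_compute_forward_mul_mat_f32` code enabled further down in this commit; the tiny matrices and the `main` wrapper are only for the example):

```
// minimal sketch of ggml's BLAS path: dst = src1 * src0^T via a single sgemm call
// assumes a CBLAS implementation (OpenBLAS or Accelerate) is available at link time
#include <cblas.h>
#include <stdio.h>

int main(void) {
    const int ne01 = 2, ne10 = 3, ne11 = 2;  // rows of src0, shared dim, rows of src1
    const float x[] = { 1, 2, 3, 4, 5, 6 };  // src0: ne01 x ne10 (row-major)
    const float y[] = { 1, 0, 1, 0, 1, 0 };  // src1: ne11 x ne10 (row-major)
    float d[4] = { 0 };                      // dst:  ne11 x ne01

    // zT = y * xT - the same call that ggml_compute_forward_mul_mat_f32 makes
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                ne11, ne01, ne10,
                1.0f, y, ne10,
                x, ne10,
                0.0f, d, ne01);

    printf("%.0f %.0f\n%.0f %.0f\n", d[0], d[1], d[2], d[3]); // prints: 4 10 / 2 5
    return 0;
}
```

Build against OpenBLAS with, e.g., `cc sketch.c -lopenblas` (with Accelerate, include `<Accelerate/Accelerate.h>` instead of `<cblas.h>`, as `ggml.c` does).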
## Limitations
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always picks the token with the highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original Python implementation, so in order to make a fair comparison between the two implementations, make sure
to run the Python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Quick start
First, download one of the Whisper models converted to [ggml format](models). For example:
@@ -59,8 +86,8 @@ For a quick demo, simply run `make base.en`:
```java
$ make base.en
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
./main -h
@@ -70,13 +97,18 @@ options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
@@ -114,23 +146,26 @@ whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem_required = 505.00 MB
whisper_model_load: mem_required = 670.00 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: ggml ctx size = 163.43 MB
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
whisper_print_timings: load time = 87.21 ms
whisper_print_timings: mel time = 24.26 ms
whisper_print_timings: sample time = 3.87 ms
whisper_print_timings: encode time = 323.67 ms / 53.94 ms per layer
whisper_print_timings: decode time = 83.25 ms / 13.87 ms per layer
whisper_print_timings: total time = 522.66 ms
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: load time = 105.91 ms
whisper_print_timings: mel time = 24.62 ms
whisper_print_timings: sample time = 3.63 ms
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
whisper_print_timings: total time = 542.81 ms
```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -168,6 +203,16 @@ make medium
make large
```
## Memory usage
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
## Another example
Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
@@ -263,42 +308,108 @@ to highlight words with high or low confidence:
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
## Implementation details
## Controlling the length of the generated text segments (experimental)
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The high-level C-style API is implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm NEON SIMD
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
## Limitations
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always picks the token with the highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original Python implementation, so in order to make a fair comparison between the two implementations, make sure
to run the Python code with the following parameters:
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
```
whisper --best_of None --beam_size None ...
[00:00:00.000 --> 00:00:00.850] And so my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:04.140] Americans, ask
[00:00:04.140 --> 00:00:05.660] not what your
[00:00:05.660 --> 00:00:06.840] country can do
[00:00:06.840 --> 00:00:08.430] for you, ask
[00:00:08.430 --> 00:00:09.440] what you can do
[00:00:09.440 --> 00:00:10.020] for your
[00:00:10.020 --> 00:00:11.000] country.
```
In the future, `whisper.cpp` will support more sampling strategies.
## Word-level timestamps
## Memory usage
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
| Model | Disk | Mem |
| --- | --- | --- |
| tiny | 75 MB | ~280 MB |
| base | 142 MB | ~430 MB |
| small | 466 MB | ~1.0 GB |
| medium | 1.5 GB | ~2.6 GB |
| large | 2.9 GB | ~4.7 GB |
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:02.850] Americans
[00:00:02.850 --> 00:00:03.300] ,
[00:00:03.300 --> 00:00:04.140] ask
[00:00:04.140 --> 00:00:04.990] not
[00:00:04.990 --> 00:00:05.410] what
[00:00:05.410 --> 00:00:05.660] your
[00:00:05.660 --> 00:00:06.260] country
[00:00:06.260 --> 00:00:06.600] can
[00:00:06.600 --> 00:00:06.840] do
[00:00:06.840 --> 00:00:07.010] for
[00:00:07.010 --> 00:00:08.170] you
[00:00:08.170 --> 00:00:08.190] ,
[00:00:08.190 --> 00:00:08.430] ask
[00:00:08.430 --> 00:00:08.910] what
[00:00:08.910 --> 00:00:09.040] you
[00:00:09.040 --> 00:00:09.320] can
[00:00:09.320 --> 00:00:09.440] do
[00:00:09.440 --> 00:00:09.760] for
[00:00:09.760 --> 00:00:10.020] your
[00:00:10.020 --> 00:00:10.510] country
[00:00:10.510 --> 00:00:11.000] .
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires `ffmpeg` to be installed.
Here are a few *"typical"* examples:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
source ./samples/jfk.wav.wts
ffplay ./samples/jfk.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b1c6-323ac4db5b2c.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
source ./samples/mm0.wav.wts
ffplay ./samples/mm0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-95f9-4227de3570aa.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
source ./samples/gb0.wav.wts
ffplay ./samples/gb0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a0cd-f28a317987ba.mp4
---
## Benchmarks


@@ -46,8 +46,6 @@ EMSCRIPTEN_BINDINGS(whisper) {
struct whisper_full_params params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
printf("full_default: available threads %d\n", std::thread::hardware_concurrency());
params.print_realtime = true;
params.print_progress = false;
params.print_timestamps = true;
@@ -57,9 +55,6 @@ EMSCRIPTEN_BINDINGS(whisper) {
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
params.offset_ms = 0;
printf("full_default: using %d threads\n", params.n_threads);
printf("full_default: language '%s'\n", params.language);
std::vector<float> pcmf32;
const int n = audio["length"].as<int>();
@@ -71,6 +66,20 @@ EMSCRIPTEN_BINDINGS(whisper) {
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(pcmf32.data()), n);
memoryView.call<void>("set", audio);
// print system information
{
printf("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
printf("%s: processing %d samples, %.1f sec, %d threads, %d processors, lang = %s, task = %s ...\n",
__func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, 1,
params.language,
params.translate ? "translate" : "transcribe");
printf("\n");
}
int ret = whisper_full(g_contexts[index], params, pcmf32.data(), pcmf32.size());
whisper_print_timings(g_contexts[index]);

File diff suppressed because one or more lines are too long

examples/main/README.md

@@ -6,21 +6,29 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./main -h
usage: ./main [options] file0.wav file1.wav ...
usage: ./bin/main [options] file0.wav file1.wav ...
options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-o N, --offset N offset in milliseconds (default: 0)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path
-h, --help show this help message and exit
```

examples/main/main.cpp

@@ -36,18 +36,34 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}
// helper function to replace substrings
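// note: pos advances by the length of the replacement, so the search never re-matches inside text that was just inserted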
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos);
if (pos == std::string::npos) break;
s.erase(pos, search.length());
s.insert(pos, replace);
}
}
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t max_context = -1;
int32_t max_len = 0;
float word_thold = 0.01f;
bool verbose = false;
bool translate = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool print_special_tokens = false;
bool print_colors = false;
bool no_timestamps = false;
@@ -73,10 +89,18 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--processors") {
params.n_processors = std::stoi(argv[++i]);
} else if (arg == "-ot" || arg == "--offset-t") {
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
@@ -94,6 +118,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.output_vtt = true;
} else if (arg == "-osrt" || arg == "--output-srt") {
params.output_srt = true;
} else if (arg == "-owts" || arg == "--output-words") {
params.output_wts = true;
} else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true;
} else if (arg == "-pc" || arg == "--print_colors") {
@@ -125,13 +151,18 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
@@ -141,17 +172,18 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n");
}
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
// print the last segment
const int i = n_segments - 1;
if (i == 0) {
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
@@ -203,6 +235,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, void * user_da
}
}
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
@@ -269,6 +302,121 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
// TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
fout << "#!/bin/bash" << "\n";
fout << "\n";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const int n = whisper_full_n_tokens(ctx, i);
std::vector<whisper_token_data> tokens(n);
for (int j = 0; j < n; ++j) {
tokens[j] = whisper_full_get_token_data(ctx, i, j);
}
if (i > 0) {
fout << ",";
}
// background text
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
{
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
if (tokens[k].id >= whisper_token_eot(ctx)) {
continue;
}
const std::string txt = whisper_token_to_str(ctx, token2.id);
txt_bg += txt;
if (k == j) {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += txt[l];
txt_ul += "_";
}
txt_fg += "|";
} else {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += "\\ ";
txt_ul += "\\ ";
}
}
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "");
::replace_all(txt_fg, "\"", "\\\"");
}
if (is_first) {
// background text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
is_first = false;
}
// foreground text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
// underline
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
}
}
fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
fout << "\n\n";
fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
fout << "\n";
fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
fout << "\n";
fout.close();
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true;
}
int main(int argc, char ** argv) {
whisper_params params;
@@ -346,7 +494,8 @@ int main(int argc, char ** argv) {
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
@@ -359,8 +508,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
@@ -380,19 +530,27 @@ int main(int argc, char ** argv) {
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
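// note: when generating karaoke output (-owts) and no explicit -ml is given, default to segments of 60 characters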
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params;
}
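// run the inference; with n_processors > 1 the audio is split into chunks that are processed in parallel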
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 8;
}
}
// output stuff
{
printf("\n");
// output to text file
@@ -412,6 +570,12 @@ int main(int argc, char ** argv) {
const auto fname_srt = fname_inp + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
}
}

examples/stream/stream.cpp

@@ -17,6 +17,7 @@
#include <string>
#include <thread>
#include <vector>
#include <fstream>
// 500 -> 00:05.000
// 6000 -> 01:00.000
@@ -38,6 +39,7 @@ struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t step_ms = 3000;
int32_t length_ms = 10000;
int32_t capture_id = -1;
bool verbose = false;
bool translate = false;
@@ -47,7 +49,7 @@ struct whisper_params {
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_inp = "samples/jfk.wav";
std::string fname_out = "";
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -64,6 +66,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.step_ms = std::stoi(argv[++i]);
} else if (arg == "--length") {
params.length_ms = std::stoi(argv[++i]);
} else if (arg == "-c" || arg == "--capture") {
params.capture_id = std::stoi(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
@@ -84,7 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_inp = argv[++i];
params.fname_out = argv[++i];
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
@@ -108,6 +112,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
fprintf(stderr, " --length N audio length in milliseconds (default: %d)\n", params.length_ms);
fprintf(stderr, " -c ID, --capture ID capture device ID (default: -1)\n");
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -kc, --keep-context keep text context from earlier audio (default: false)\n");
@@ -115,7 +120,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
fprintf(stderr, " -f FNAME, --file FNAME text output file name (default: no output to file)\n");
fprintf(stderr, "\n");
}
@@ -143,9 +148,9 @@ bool audio_sdl_init(const int capture_id) {
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
printf("%s: found %d capture devices:\n", __func__, nDevices);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
printf("%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
}
@@ -163,21 +168,21 @@ bool audio_sdl_init(const int capture_id) {
capture_spec_requested.samples = 1024;
if (capture_id >= 0) {
printf("%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
printf("%s: attempt to open default capture device ...\n", __func__);
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!g_dev_id_in) {
printf("%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
g_dev_id_in = 0;
} else {
printf("%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
printf("%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
printf("%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
printf("%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
printf("%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
}
@@ -200,7 +205,7 @@ int main(int argc, char ** argv) {
// init audio
if (!audio_sdl_init(-1)) {
if (!audio_sdl_init(params.capture_id)) {
fprintf(stderr, "%s: audio_sdl_init() failed!\n", __func__);
return 1;
}
@@ -212,6 +217,7 @@ int main(int argc, char ** argv) {
const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
std::vector<float> pcmf32(n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
@@ -219,15 +225,15 @@ int main(int argc, char ** argv) {
// print some info about the processing
{
printf("\n");
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
printf("%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
n_samples,
float(n_samples)/WHISPER_SAMPLE_RATE,
@@ -237,8 +243,8 @@ int main(int argc, char ** argv) {
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
printf("%s: n_new_line = %d\n", __func__, n_new_line);
printf("\n");
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
fprintf(stderr, "\n");
}
SDL_PauseAudioDevice(g_dev_id_in, 0);
@@ -246,6 +252,18 @@ int main(int argc, char ** argv) {
int n_iter = 0;
bool is_running = true;
std::ofstream fout;
if (params.fname_out.length() > 0) {
fout.open(params.fname_out);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
return 1;
}
}
printf("[Start speaking]");
fflush(stdout);
// main audio loop
while (is_running) {
// process SDL events:
@@ -253,13 +271,18 @@ int main(int argc, char ** argv) {
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
break;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
// process new audio
if (n_iter > 0 && SDL_GetQueuedAudioSize(g_dev_id_in) > 2*n_samples*sizeof(float)) {
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
@@ -312,6 +335,11 @@ int main(int argc, char ** argv) {
{
printf("\33[2K\r");
// print long empty line to clear the previous line
printf("%s", std::string(100, ' ').c_str());
printf("\33[2K\r");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
@@ -319,15 +347,27 @@ int main(int argc, char ** argv) {
if (params.no_timestamps) {
printf("%s", text);
fflush(stdout);
if (params.fname_out.length() > 0) {
fout << text;
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
if (params.fname_out.length() > 0) {
fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "] " << text << std::endl;
}
}
}
if (params.fname_out.length() > 0) {
fout << std::endl;
}
}
++n_iter;
if ((n_iter % n_new_line) == 0) {

examples/whisper.nvim/README.md

@@ -0,0 +1,92 @@
# whisper.nvim
Speech-to-text in Neovim
The transcription is performed on the CPU and no data leaves your computer. Works best on Apple Silicon devices.
https://user-images.githubusercontent.com/1991296/198382564-784e9663-2037-4d04-99b8-f39136929b7e.mp4
## Usage
- Simply press `Ctrl-G` in `INSERT`, `VISUAL` or `NORMAL` mode and say something
- When you are done - press `Ctrl-C` to end the transcription and insert the transcribed text under the cursor
## Installation
*Note: this is a bit tedious and hacky at the moment, but I hope it will be improved with time*
- Clone this repo and build the `stream` tool:
```
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
make stream
```
- Download the `base.en` Whisper model (140 MB):
```
./models/download-ggml-model.sh base.en
```
- Place the [whisper.nvim](whisper.nvim) script somewhere in your PATH and give it execute permissions:
```
cp examples/whisper.nvim/whisper.nvim ~/bin/
chmod u+x ~/bin/whisper.nvim
```
- Fine-tune the script to your preference and machine parameters:
```
./stream -t 8 -m models/ggml-base.en.bin --step 350 --length 10000 -f /tmp/whisper.nvim 2> /dev/null
```
On slower machines, try to increase the `step` parameter.
- Add the following shortcuts to your `~/.config/nvim/init.vim`:
```
inoremap <C-G> <C-O>:!whisper.nvim<CR><C-O>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR><C-R>a
nnoremap <C-G> :!whisper.nvim<CR>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR>"ap
vnoremap <C-G> c<C-O>:!whisper.nvim<CR><C-O>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR><C-R>a
```
Explanation: pressing `Ctrl-G` runs the [whisper.nvim](whisper.nvim) script which in turn calls the `stream` binary to transcribe your speech through the microphone. The results from the transcription are continuously dumped into `/tmp/whisper.nvim`. After you kill the program with `Ctrl-C`, the vim command grabs the last line from the `/tmp/whisper.nvim` file and puts it under the cursor.
There is probably a much more intelligent way to achieve all this, but this is what I could hack together in an hour. Any suggestions on how to improve it are welcome.
You are now ready to use speech-to-text in Neovim!
## TODO
There are a lot of ways to improve this idea and I don't have much experience with Vim plugin programming, so contributions are welcome!
- [ ] **Wrap this into a plugin**
It would be great to make a standalone plugin out of this that can be installed with `vim-plug` or similar
- [ ] **Simplify the `init.vim` mappings (maybe factor out the common call into a separate function)**
- [ ] **Add Copilot/GPT-3 integration**
This is probably a very long shot, but I think it would be very cool to have the functionality to select some code and then hit Ctrl-G and say something like:
*"refactor this using stl containers"*
or
*"optimize by sorting the data first"*
The plugin would then make an appropriate query using the selected text and code context to Copilot or GPT-3 and return the result.
Here is a proof-of-concept:
https://user-images.githubusercontent.com/1991296/199078847-0278fcde-5667-4748-ba0d-7d55381d6047.mp4
https://user-images.githubusercontent.com/1991296/200067939-f98d2ac2-7519-438a-85f9-79db0841ba4f.mp4
For an explanation of how this works, see: https://twitter.com/ggerganov/status/1587168771789258756
## Discussion
If you find this idea interesting, you can join the discussion here: https://github.com/ggerganov/whisper.cpp/discussions/108

examples/whisper.nvim/whisper.nvim

@@ -0,0 +1,50 @@
#!/bin/bash
# INSTRUCTIONS
#
# This simple script is called by Neovim to capture audio from the microphone and transcribe it with Whisper.
# In order for this to work, you need to clone the whisper.cpp repo and build the 'stream' tool
#
# git clone https://github.com/ggerganov/whisper.cpp
# cd whisper.cpp
# make stream
#
# Also, make sure the current script is in your PATH env variable. You should be able to run the following command:
#
# whisper.nvim
#
# Next, export the path to the whisper.cpp repository via the WHISPER_CPP_HOME env variable:
#
# export WHISPER_CPP_HOME=/path/to/whisper.cpp
#
# Finally, add the following lines to your ~/.config/nvim/init.vim:
#
# inoremap <C-G> <C-O>:!whisper.nvim<CR><C-O>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR><C-R>a
# nnoremap <C-G> :!whisper.nvim<CR>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR>"ap
# vnoremap <C-G> c<C-O>:!whisper.nvim<CR><C-O>:let @a = system("cat /tmp/whisper.nvim \| tail -n 1 \| xargs -0 \| tr -d '\\n' \| sed -e 's/^[[:space:]]*//'")<CR><C-R>a
#
# This allows you to press Ctrl-G in order to capture audio from the microphone and transcribe it.
# When you are done speaking - press Ctrl-C
#
# the Whisper model to use
model="base.en"
# export the path to the whisper.cpp repo in the WHISPER_CPP_HOME env variable
# https://github.com/ggerganov/whisper.cpp
cd "${WHISPER_CPP_HOME}"
if [ ! -f ./stream ] ; then
echo "whisper.nvim: the 'stream' executable was not found! WHISPER_CPP_HOME=${WHISPER_CPP_HOME}" > /tmp/whisper.nvim
exit 1
fi
if [ ! -f ./models/ggml-${model}.bin ] ; then
echo "whisper.nvim: the '$model' model was not found! WHISPER_CPP_HOME=${WHISPER_CPP_HOME}" > /tmp/whisper.nvim
exit 2
fi
# fine-tune the parameters according to your machine specs
./stream -t 8 -m models/ggml-${model}.bin --step 350 --length 10000 -f /tmp/whisper.nvim 2> /dev/null
exit 0


@@ -469,7 +469,6 @@
printTextarea('js: processing - this might take a while ...');
printTextarea('js: the page will be unresponsive until the processing is completed');
printTextarea('');
printTextarea('');
setTimeout(function() {
var ret = Module.full_default(instance, audio, document.getElementById('language').value, translate);

extra/sha-all.sh (new executable file)

@@ -0,0 +1,7 @@
#!/bin/bash
# Compute the SHA1 of all model files in ./models/ggml-*.bin
for f in ./models/ggml-*.bin; do
shasum -a 1 "$f"
done

ggml.c

@@ -14,7 +14,7 @@
#include <stdint.h>
#include <stdio.h>
#if defined _MSC_VER
#if defined _MSC_VER || defined(__MINGW32__)
#include <Windows.h>
typedef volatile LONG atomic_int;
@@ -44,6 +44,11 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
static int pthread_join(pthread_t thread, void* unused) {
return (int) WaitForSingleObject(thread, INFINITE);
}
static int sched_yield (void) {
Sleep (0);
return 0;
}
#else
#include <pthread.h>
#include <stdatomic.h>
@@ -76,6 +81,8 @@ typedef void* thread_ret_t;
#ifdef GGML_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#elif GGML_USE_OPENBLAS
#include <cblas.h>
#endif
// floating point type used to accumulate sums
@@ -191,7 +198,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
// timing
//
#if defined(_MSC_VER)
#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq;
void ggml_time_init(void) {
LARGE_INTEGER frequency;
@@ -1284,6 +1291,7 @@ struct ggml_state {
// global state
struct ggml_state g_state;
atomic_int g_state_barrier = 0;
////////////////////////////////////////////////////////////////////////////////
@@ -1413,6 +1421,17 @@ int ggml_up64(int n) {
////////////////////////////////////////////////////////////////////////////////
struct ggml_context * ggml_init(struct ggml_init_params params) {
// make this function thread safe
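// g_state_barrier acts as a counting spinlock: atomic_fetch_add returns the previous
// value, so a thread that observes 0 has exclusive access, while a thread that
// observes a value > 0 backs off, yields the CPU and tries again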
{
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield();
processing = atomic_fetch_add(&g_state_barrier, 1);
}
}
static bool is_first_call = true;
if (is_first_call) {
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -1456,6 +1475,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
if (ctx == NULL) {
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
atomic_fetch_sub(&g_state_barrier, 1);
return NULL;
}
@@ -1470,10 +1492,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
ggml_assert_aligned(ctx->mem_buffer);
GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
atomic_fetch_sub(&g_state_barrier, 1);
return ctx;
}
void ggml_free(struct ggml_context * ctx) {
// make this function thread safe
{
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield();
processing = atomic_fetch_add(&g_state_barrier, 1);
}
}
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
if (&g_state.contexts[i].context == ctx) {
g_state.contexts[i].used = false;
@@ -1485,11 +1522,15 @@ void ggml_free(struct ggml_context * ctx) {
free(ctx->mem_buffer);
}
atomic_fetch_sub(&g_state_barrier, 1);
return;
}
}
GGML_PRINT_DEBUG("%s: context not found\n", __func__);
atomic_fetch_sub(&g_state_barrier, 1);
}
size_t ggml_used_mem(const struct ggml_context * ctx) {
@@ -3259,7 +3300,10 @@ void ggml_compute_forward_add_f32(
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
for (int j = ith; j < n; j += nth) {
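// partition the rows into contiguous chunks, one per thread (the last thread also takes the remainder)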
const int j0 = (n/nth)*ith;
const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
for (int j = j0; j < j1; j++) {
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01),
@@ -4205,46 +4249,44 @@ void ggml_compute_forward_mul_mat_f32(
// nb00 < nb01 - src0 is transposed
// compute by src0 columns
//#ifdef GGML_USE_ACCELERATE
// if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
// GGML_ASSERT(ggml_is_contiguous(src0));
// GGML_ASSERT(nb10 == sizeof(float));
//
// if (params->ith != 0) return;
//
// if (params->type == GGML_TASK_INIT) {
// return;
// }
//
// if (params->type == GGML_TASK_FINALIZE) {
// return;
// }
//
// float * const wdata = params->wdata;
//
// for (int i03 = 0; i03 < ne03; i03++) {
// for (int i02 = 0; i02 < ne02; i02++) {
// const float * x = (float *) (src0->data);
// const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
//
// float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
//
// // zT = y * xT
// {
// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
// ne11, ne01, ne10,
// 1.0f, y, ne10,
// x, ne10,
// 0.0f, d, ne01);
// }
// }
// }
//
// //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
//
// return;
// }
//#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(nb10 == sizeof(float));
if (params->ith != 0) return;
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
const float * x = (float *) (src0->data);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// zT = y * xT
{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
x, ne10,
0.0f, d, ne01);
}
}
}
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
}
#endif
if (params->type == GGML_TASK_INIT) {
if (nb01 >= nb00) {
@@ -4451,7 +4493,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
// nb00 < nb01 - src0 is transposed
// compute by src0 columns
#ifdef GGML_USE_ACCELERATE
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));
@@ -6968,7 +7010,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_ADD:
{
node->n_tasks = 1;
node->n_tasks = n_threads;
} break;
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -7007,7 +7049,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} else {
if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
#ifdef GGML_USE_ACCELERATE
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
} else {
@@ -8200,7 +8242,7 @@ int ggml_cpu_has_avx512(void) {
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_NEON__)
#if defined(__ARM_NEON)
return 1;
#else
return 0;
@@ -8224,7 +8266,7 @@ int ggml_cpu_has_wasm_simd(void) {
}
int ggml_cpu_has_blas(void) {
#if defined(GGML_USE_BLAS) || defined(GGML_USE_ACCELERATE)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
return 1;
#else
return 0;

ggml.h

@@ -11,7 +11,7 @@ extern "C" {
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 16
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
#ifdef __ARM_NEON

models/README.md

@@ -22,6 +22,20 @@ A third option to obtain the model files is to download them from Hugging Face:
https://huggingface.co/datasets/ggerganov/whisper.cpp/tree/main
## Available models
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| tiny.en | 75 MB | ~390 MB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| base.en | 142 MB | ~500 MB | `137c40403d78fd54d454da0f9bd998f78703390c` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| small.en | 466 MB | ~1.0 GB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
## Model files for testing purposes
The model files prefixed with `for-tests-` are empty (i.e. they do not contain any weights) and are used by the CI for testing purposes.

models/download-ggml-model.cmd

@@ -0,0 +1,63 @@
@echo off
pushd %~dp0
set models_path=%CD%
popd
set argc=0
for %%x in (%*) do set /A argc+=1
set models=tiny.en tiny base.en base small.en small medium.en medium large
if %argc% neq 1 (
echo.
echo Usage: download-ggml-model.cmd model
CALL :list_models
goto :eof
)
set model=%1
for %%b in (%models%) do (
if "%%b"=="%model%" (
CALL :download_model
goto :eof
)
)
echo Invalid model: %model%
CALL :list_models
goto :eof
:download_model
echo Downloading ggml model %model%...
cd %models_path%
if exist "ggml-%model%.bin" (
echo Model %model% already exists. Skipping download.
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%
echo Please try again later or download the original Whisper model files and convert them yourself.
goto :eof
)
echo Done! Model %model% saved in %models_path%\ggml-%model%.bin
echo You can now use it like this:
echo main.exe -m %models_path%\ggml-%model%.bin -f %models_path%\..\samples\jfk.wav
goto :eof
:list_models
echo.
echo Available models:
(for %%a in (%models%) do (
echo %%a
))
echo.
exit /b

whisper.cpp

@@ -1,3 +1,4 @@
#define WHISPER_BUILD
#include "whisper.h"
#include "ggml.h"
@@ -132,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 86ull*MB },
{ MODEL_BASE, 165ull*MB },
{ MODEL_SMALL, 540ull*MB },
{ MODEL_MEDIUM, 1650ull*MB },
{ MODEL_LARGE, 3260ull*MB },
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@@ -210,14 +219,6 @@ struct whisper_vocab {
}
};
struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float pt; // probability of the timestamp token
};
struct whisper_segment {
int64_t t0;
int64_t t1;
@@ -386,6 +387,7 @@ struct whisper_model {
// context
struct ggml_context * ctx;
struct ggml_context * ctx_mem;
// tensors
int n_loaded;
@@ -400,7 +402,8 @@ struct whisper_context {
int64_t t_decode_us = 0;
int64_t t_start_us = 0;
std::vector<uint8_t> buf_model;
std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
std::vector<uint8_t> buf_memory;
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer;
@@ -412,10 +415,15 @@ struct whisper_context {
std::vector<float> probs;
std::vector<float> logits;
std::vector<whisper_token_data> tokens_cur;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
@@ -429,7 +437,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
@@ -502,13 +510,16 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model.size() +
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
@@ -591,6 +602,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
size_t ctx_size = 0;
size_t ctx_mem_size = 0;
{
const auto & hparams = model.hparams;
@@ -699,11 +711,11 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
}
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@ -713,8 +725,8 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// create the ggml context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_model.size(),
.mem_buffer = wctx.buf_model.data(),
.mem_size = wctx.buf_model->size(),
.mem_buffer = wctx.buf_model->data(),
};
model.ctx = ggml_init(params);
@ -920,9 +932,23 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory
{
auto & ctx = model.ctx;
auto & ctx = model.ctx_mem;
const auto & hparams = model.hparams;
@ -1042,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
bool whisper_encode(
static bool whisper_encode(
whisper_context & wctx,
const int n_threads,
const int mel_offset) {
@ -1428,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode(
static bool whisper_decode(
whisper_context & wctx,
const int n_threads,
const whisper_token * tokens,
@ -1791,10 +1817,12 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best(
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
whisper_token_data result;
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
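// field order above: { id, tid, p, pt, ptsum, t0, t1, vlen }; t0/t1 start at -1 ("unknown")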
int n_logits = vocab.id_to_token.size();
@ -1831,7 +1859,8 @@ whisper_token_data whisper_sample_best(
}
}
result.pt = max_ts/(sum_ts + 1e-6);
result.pt = max_ts/(sum_ts + 1e-10);
result.ptsum = sum_ts;
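// pt is the probability of the most likely timestamp token normalized over all timestamp
// tokens, while ptsum is the total probability mass assigned to timestamp tokens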
}
// find the top K tokens
@ -1866,7 +1895,7 @@ whisper_token_data whisper_sample_best(
}
// samples only from the timestamp tokens
whisper_vocab::id whisper_sample_timestamp(
static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
@ -1898,14 +1927,19 @@ whisper_vocab::id whisper_sample_timestamp(
return probs_id[0].second;
}
static std::string to_timestamp(int64_t t) {
int64_t sec = t/100;
int64_t msec = t - sec*100;
int64_t min = sec/60;
sec = sec - min*60;
//  500 -> 00:00:05.000
// 6000 -> 00:01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
int64_t msec = t * 10;
int64_t hr = msec / (1000 * 60 * 60);
msec = msec - hr * (1000 * 60 * 60);
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
return std::string(buf);
}
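A quick check of the new format (timestamps are in units of 10 ms):

to_timestamp(500);        // "00:00:05.000"
to_timestamp(6000);       // "00:01:00.000"
to_timestamp(6000, true); // "00:01:00,000" - the comma-separated form used by SRT subtitles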
@ -1913,7 +1947,7 @@ static std::string to_timestamp(int64_t t) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
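// out[2*k + 0] =  sum_n in[n]*cos(2*pi*k*n/N)   (real part)
// out[2*k + 1] = -sum_n in[n]*sin(2*pi*k*n/N)   (imaginary part)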
void dft(const std::vector<float> & in, std::vector<float> & out) {
static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
@ -1937,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
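// recursive radix-2 Cooley-Tukey: split the input into even/odd samples, transform each
// half, then combine with twiddle factors; odd-sized inputs fall back to the naive dft()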
void fft(const std::vector<float> & in, std::vector<float> & out) {
static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
@ -1988,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram(
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
@ -2127,6 +2161,9 @@ struct whisper_context * whisper_init(const char * path_model) {
void whisper_free(struct whisper_context * ctx) {
if (ctx) {
if (ctx->buf_model) {
delete ctx->buf_model;
}
delete ctx;
}
}
@ -2189,7 +2226,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return 0;
}
whisper_token whisper_sample_best(struct whisper_context * ctx) {
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
@ -2197,7 +2234,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
return res.id;
return res;
}
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
@ -2300,6 +2337,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.translate =*/ false,
@ -2309,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2331,6 +2374,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.translate =*/ false,
@ -2340,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2361,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
auto segment = ctx->result_all.back();
int res = 1;
int acc = 0;
std::string text;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
if (token.id >= whisper_token_eot(ctx)) {
continue;
}
const auto txt = whisper_token_to_str(ctx, token.id);
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) {
// split here
ctx->result_all.back().text = std::move(text);
ctx->result_all.back().t1 = token.t0;
ctx->result_all.back().tokens.resize(i);
ctx->result_all.push_back({});
ctx->result_all.back().t0 = token.t0;
ctx->result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
ctx->result_all.back().tokens.insert(
ctx->result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
acc = 0;
text = "";
segment = ctx->result_all.back();
i = -1;
res++;
} else {
acc += cur;
text += txt;
}
}
ctx->result_all.back().text = std::move(text);
return res;
}
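In user code, wrapping is enabled through the max_len parameter and requires token-level timestamps; a minimal sketch (the values are illustrative):

struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.token_timestamps = true; // whisper_wrap_segment() is only invoked on this path
wparams.max_len          = 16;   // wrap segments longer than 16 characters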
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@ -2368,7 +2479,6 @@ int whisper_full(
int n_samples) {
// clear old results
auto & result_all = ctx->result_all;
auto & tokens_cur = ctx->tokens_cur;
result_all.clear();
@ -2378,10 +2488,19 @@ int whisper_full(
return -1;
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
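// half-window of 32 samples ~ 2 ms at WHISPER_SAMPLE_RATE = 16000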
}
const int seek_start = params.offset_ms/10;
// if the length of the spectrogram is less than 1s (100 mel frames, 10 ms each), then return
// basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (whisper_n_len(ctx) < 100) {
if (whisper_n_len(ctx) < 100 + seek_start) {
return 0;
}
@ -2405,8 +2524,14 @@ int whisper_full(
int progress_prev = 0;
int progress_step = 5;
std::vector<whisper_token_data> tokens_cur;
tokens_cur.reserve(whisper_n_text_ctx(ctx));
std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
// main loop
int seek = params.offset_ms/10;
int seek = seek_start;
while (true) {
int progress_cur = (100*seek)/whisper_n_len(ctx);
while (progress_cur >= progress_prev + progress_step) {
@ -2426,13 +2551,12 @@ int whisper_full(
return 7;
}
std::vector<whisper_token> prompt;
int n_past = 0;
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
if (prompt_past.size() > 0) {
int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
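// the conditioning prompt is thus: [ PREV, last n_take previously decoded tokens ]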
@ -2474,7 +2598,7 @@ int whisper_full(
// feel free to experiment!
//
{
auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
auto token = whisper_sample_best(ctx);
if (i == 0) {
token.tid = whisper_token_beg(ctx);
@ -2490,7 +2614,10 @@ int whisper_full(
prompt.push_back(token.id);
tokens_cur.push_back(token);
//printf("%s: %s\n", __func__, ctx->vocab.id_to_token[id].c_str());
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
if (token.id == whisper_token_eot(ctx)) {
@ -2517,6 +2644,7 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
@ -2555,8 +2683,19 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
text = "";
@ -2585,8 +2724,19 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
@ -2597,6 +2747,156 @@ int whisper_full(
return 0;
}
int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
const int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
int ret = 0;
// prepare separate contexts for each thread
std::vector<struct whisper_context> ctxs(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
ctxs[i] = *ctx;
auto & model = ctxs[i].model;
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = ctxs[i].buf_memory.size(),
.mem_buffer = ctxs[i].buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return -1;
}
}
// separate key + value memory for each processor
{
auto & ctx = model.ctx_mem;
const auto & hparams = model.hparams;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;
// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
// key/value memory for the cross-attention layer
{
const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
}
}
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std::vector<std::thread> workers(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
auto params_cur = params;
params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
auto params_cur = params;
ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {
workers[i].join();
}
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
// combine results into ctx->result_all
for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = ctxs[i].result_all;
for (int j = 0; j < (int) results_i.size(); ++j) {
// correct the segment timestamp taking into account the offset
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (ctx->result_all.size() > 0) {
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
}
ctx->result_all.push_back(std::move(results_i[j]));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
}
}
ctx->t_mel_us += ctxs[i].t_mel_us;
ctx->t_sample_us += ctxs[i].t_sample_us;
ctx->t_encode_us += ctxs[i].t_encode_us;
ctx->t_decode_us += ctxs[i].t_decode_us;
}
// average the timings
ctx->t_mel_us /= n_processors;
ctx->t_sample_us /= n_processors;
ctx->t_encode_us /= n_processors;
ctx->t_decode_us /= n_processors;
// print information about the audio boundaries
fprintf(stderr, "\n");
fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
for (int i = 0; i < n_processors - 1; ++i) {
fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
}
fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
return ret;
}
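A sketch of the intended call pattern (the model path and the pcmf32 sample buffer are placeholders; the audio is expected as 16 kHz mono float samples):

struct whisper_context * ctx = whisper_init("models/ggml-base.en.bin");
struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), 4) == 0) {
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        printf("%s", whisper_full_get_segment_text(ctx, i));
    }
}
whisper_free(ctx);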
int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->result_all.size();
}
@ -2625,6 +2925,10 @@ whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segm
return ctx->result_all[i_segment].tokens[i_token].id;
}
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token];
}
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].p;
}
@ -2642,3 +2946,304 @@ const char * whisper_print_system_info() {
return s.c_str();
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a more robust implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
static int64_t sample_to_timestamp(int i_sample) {
return (100*i_sample)/WHISPER_SAMPLE_RATE;
}
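// with WHISPER_SAMPLE_RATE = 16000, one timestamp unit (10 ms) spans 160 samples,
// e.g. timestamp_to_sample(100, n_samples) maps 1 s to sample 16000 (clamped to the signal)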
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
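// worked example: voice_length("Hello, world.") = 5*1.00 + 2.00 (',') + 0.01 (' ') + 5*1.00 + 3.00 ('.') = 15.01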
// moving average of fabs(signal) over a window of 2*hw + 1 samples
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
const int hw = n_samples_per_half_window;
std::vector<float> result(n_samples);
for (int i = 0; i < n_samples; i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < n_samples) {
sum += fabs(signal[i + j]);
}
}
result[i] = sum/(2*hw + 1);
}
return result;
}
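// note: near the signal edges the sum covers fewer than 2*hw + 1 samples but is still divided
// by the full window size, so the energy estimate naturally tapers off at the boundaries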
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = ctx->result_all[i_segment];
auto & tokens = segment.tokens;
const int n_samples = ctx->energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {
return;
}
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
auto & t_beg = ctx->t_beg;
auto & t_last = ctx->t_last;
auto & tid_last = ctx->tid_last;
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
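// worked example: two unknown tokens with vlen 1.0 and 3.0 spanning dt = 400 (4 s) get their
// shared boundary at t0 + 400*1.0/4.0 = t0 + 100, i.e. 1 s into the interval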
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += ctx->energy[k];
}
const float thold = 0.5*sum/ns;
{
int k = s0;
if (ctx->energy[k] > thold && j > 0) {
while (k > 0 && ctx->energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (ctx->energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
{
int k = s1;
if (ctx->energy[k] > thold) {
while (k < n_samples - 1 && ctx->energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (ctx->energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
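// in short: each token boundary is pushed outward while the local energy stays above thold
// and pulled inward while it stays below, then clamped against the neighboring tokens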
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}

whisper.h
View File

@ -68,6 +68,22 @@ extern "C" {
typedef int whisper_token;
typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
@ -120,7 +136,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found
@ -163,12 +179,13 @@ extern "C" {
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
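Since the callback now reports how many segments were just added, an implementation can process exactly the new ones; a minimal sketch (the function name is illustrative):

static void on_new_segment(struct whisper_context * ctx, int n_new, void * user_data) {
    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = n_segments - n_new; i < n_segments; ++i) {
        printf("%s", whisper_full_get_segment_text(ctx, i));
    }
}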
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int offset_ms;
bool translate;
@ -178,6 +195,12 @@ extern "C" {
bool print_realtime;
bool print_timestamps;
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
const char * language;
struct {
@ -204,6 +227,16 @@ extern "C" {
const float * samples,
int n_samples);
// Split the input audio into chunks and process each chunk separately using whisper_full()
// Each chunk is transcribed on its own thread, which can offer a speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
const int n_processors);
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
@ -222,6 +255,10 @@ extern "C" {
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
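A sketch of reading back per-token results after a successful whisper_full() call (assumes token_timestamps was enabled, otherwise t0/t1 are not meaningful):

for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
    for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
        const whisper_token_data data = whisper_full_get_token_data(ctx, i, j);
        printf("[%lld -> %lld] %s\n", (long long) data.t0, (long long) data.t1,
               whisper_full_get_token_text(ctx, i, j));
    }
}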