Compare commits

...

154 Commits

Author SHA1 Message Date
8cbc363561 coreml : attempt to fix ANE-optimized models 2023-07-11 23:03:53 +03:00
4774d2feb0 whisper : minor OpenVINO refactoring (#1037)
Hopefully I didn't break something - haven't tested
2023-07-04 20:28:27 +03:00
6f0114f4a6 go : call SetDuration appropriately (#1077) 2023-07-04 16:13:25 +03:00
66616dbd4d go : fix context.Process call in examples (#1067) 2023-07-04 16:05:35 +03:00
62b81276e0 whisper : add OpenVINO support (#1037)
* openvino: use OpenVINO encoder inference

* openvino: add python script for OpenVINO model generation

* whisper: Fix 'unused' warnings when OpenVINO isn't enabled in build

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* whisper: Fix compilation error

* whisper: revert whisper_get_openvino_path_encoder & whisper_get_openvino_path_cache to non-const func signatures

* cmake: Add openvino-encoder as separate object target

* whisper : minor style fixes

* minor : indentation fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-04 15:56:11 +03:00
176d7e4e7b readme : better wording (#1064) 2023-07-04 15:30:31 +03:00
70e6fcd78b readme : add tinydiarize instructions (#1058) 2023-07-04 09:51:22 +03:00
c8d0f5fe98 whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058)
* add HuggingFace mirror to download  ggml model

* support tdrz via simple hack overriding solm tokens

* fix incorrect translate/transcribe token_ids that are not static const

* add apollo 13 sample for tdrz demo

* render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token

* extend whisper_segment with speaker_turn_next field and save in json output

* fix failing go build

* slipped in some python syntax whoops

* whisper : finalize tinydiarize support (add flag + fixes)

* whisper : tdrz support for word-level timestamps (respect max_len)

* java : try to fix tests after adding tdrz_enable flag

* main : remove TODO leftover

* java : fix params order list after adding "tdrz_enable"

* whisper : fix solm and add nosp token

* main : print tinydiarize help

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-04 09:45:00 +03:00
fdf58a6668 talk-llama : fix new rope interface 2023-07-03 19:24:01 +03:00
8ba42095c5 Revert "ggml : do not use _GNU_SOURCE gratuitously (#1027)"
This reverts commit 3f7a03ebe3.
2023-07-02 21:53:52 +03:00
d6509bf78d ggml : sync latest repo (mostly refactoring changes) 2023-07-02 21:46:09 +03:00
85ed71aaec talk-llama : fix build on macOS (#1062)
* talk-llama : use posix_madvise() instead of madvise() derived from BSD

sed -i 's,\<madvise\>,posix_&,g;s,\<MADV_,POSIX_&,g' examples/talk-llama/llama-util.h

* make : enable Darwin extensions for macOS builds

This is an attempt at fixing macOS build error coming from the fact that
RLIMIT_MEMLOCK define is not available there without Darwin extensions.
2023-06-28 22:34:50 +03:00
49c9472fa0 extra : update 'quantize-all.sh' to quantize all downloaded models (#1054)
Script will now do what it says: quantize everything except testing models in the 'models'  directory.
2023-06-28 22:07:02 +03:00
72deb41eb2 whisper : split_on_word no longer trims (#1046) 2023-06-25 23:51:01 +03:00
3f7a03ebe3 ggml : do not use _GNU_SOURCE gratuitously (#1027)
* Do not use _GNU_SOURCE gratuitously.

What is needed to build whisper.cpp and examples is availability of
stuff defined in The Open Group Base Specifications Issue 6
(https://pubs.opengroup.org/onlinepubs/009695399/) known also as
Single Unix Specification v3 (SUSv3) or POSIX.1-2001 + XSI extensions.

There is no need to penalize musl libc which simply follows standards.

Not having feature test macros in source code gives greater flexibility
to those wanting to reuse it in 3rd party app, as they can build it with
minimal FTM (_XOPEN_SOURCE=600) or other FTM depending on their needs.

It builds without issues in Alpine (musl libc), Ubuntu (glibc), MSYS2.

* examples : include SDL headers before other headers

This is an attempt at fixing macOS build error coming from SDL2 relying
on Darwin extension memset_pattern4/8/16 coming from Apple's string.h.
2023-06-25 16:34:30 +03:00
62642bb61c talk-llama : fix build after ggml sync (#1049)
sed -i 's,GGML_BACKEND_CUDA,GGML_BACKEND_GPU,g' examples/talk-llama/llama.cpp
2023-06-25 16:13:50 +03:00
f1c9df5806 metal : sync ggml-metal (ref #1047) 2023-06-25 15:40:39 +03:00
6c25fae1c4 opencl : sync latest ggml-opencl 2023-06-25 15:38:30 +03:00
44cb044e66 whisper : fix build with -Werror=undef (#1045) 2023-06-25 15:30:39 +03:00
6c68218e3c models : add ggml_to_pt script (#1042)
* adding ggml_to_pt

* typo sys too many args

* fixing swap errors dimensions

---------

Co-authored-by: simonMoisselin <simon.moisselin@gmail.com>
2023-06-25 15:29:54 +03:00
f11f33f1c0 models : cd statements are quoted to allow spaces in path (#1041) 2023-06-25 15:27:28 +03:00
8ac23c9f77 models : handle paths with spaces in download script (close #1038) 2023-06-25 15:23:23 +03:00
14baf2e7f3 main : add diarization support for all current output types (#1031)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-06-25 15:07:57 +03:00
bc2dcf85fe readme : add java alternative binding (#1029)
Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
2023-06-25 14:46:07 +03:00
1e45911f1a go : add support for whisper_full_lang_id() (#1010)
* * Add support for whisper_full_lang_id() to go bindings

* Expose token.id so we can test beg, eot etc

---------

Co-authored-by: Jay Binks <jay.binks@overthewire.com.au>
2023-06-25 14:45:33 +03:00
67564201ec go : fix "cb" -> "callNewSegment" 2023-06-25 14:34:10 +03:00
5feb0dffba ggml : sync latest ggml lib 2023-06-25 14:30:44 +03:00
7dfc11843c go : improve progress reporting and callback handling (#1024)
- Rename `cb` to `callNewSegment` in the `Process` function
- Add `callProgress` as a new parameter to the `Process` function
- Introduce `ProgressCallback` type for reporting progress during processing
- Update `Whisper_full` function to include `progressCallback` parameter
- Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks

Signed-off-by: appleboy <appleboy.tw@gmail.com>
2023-06-25 14:07:55 +03:00
6a7f3b8db2 make : update cuBLAS build both x86 and aarch64 (#1015)
make cuBLAS compilation compatible with x86 as well as aarch64.
2023-06-25 13:59:48 +03:00
207a12f5bc make : fix for CUDA native not working as an option on Ubuntu (#1012) 2023-06-25 13:57:18 +03:00
26b70395ff main : exit gracefully when invalid params are passed
* Refactor whisper_params_parse to return false on failure

* Updated help flag behavior
2023-06-25 13:52:29 +03:00
598f607e28 main : gracefully exit when invalid params are passed (#1002)
* Refactor whisper_params_parse to return false on failure

* Updated help flag behavior
2023-06-25 13:51:59 +03:00
3ec7bfffe0 py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer files (#1001)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer

* typo fix
2023-06-25 13:50:14 +03:00
a7f822ef59 readme : corrected syntax for markdown link (#995) 2023-06-25 13:46:44 +03:00
57543c169e updated java README 2023-06-06 10:27:26 +10:00
5b9e59bc07 speak scripts for Windows 2023-06-01 22:45:00 +10:00
3f7436e8a0 updated README for java 2023-06-01 16:55:48 +10:00
ce6f747064 whisper.android : support decode wav file has 2 channels (#972) 2023-05-31 10:13:14 +03:00
d7c936b44a Feature/java bindings2 (#944)
* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams
2023-05-29 09:38:58 +10:00
9b926844e3 models : fix README.md (#964)
Fixes typo on line 76 of models/README.md
2023-05-27 10:40:28 +03:00
5e2b3407ef examples : update elevenlabs scripts to use official python API (#837)
* Update elevenlabs example to use ufficial python API

* Update elevenlabs example to use official python API
2023-05-24 21:11:01 +03:00
4e16a8fb63 readme : highlight OpenBLAS support (#956)
* highlight openblas support

* Update README.md
2023-05-24 11:23:51 +03:00
77eab3fbfe talk-llama : sync latest llama.cpp (close #922, close #954) 2023-05-23 14:04:39 +03:00
041be06d58 cmake : build with any BLAS compatible library (#927)
* Build with any BLAS library

* ci: Removed explicit CUDA nvcc path
2023-05-20 21:23:45 +03:00
429b9785c0 ggml : update WASM SIMD 2023-05-20 20:00:06 +03:00
e410cfc3ce ggml : sync latest ggml repo
- new Q4 and Q8 quantization
- updated CUDA
2023-05-20 18:56:30 +03:00
bc89f285d8 bindings : add java bindings (#931)
* WIP - java bindings

* updated README

* failed attempt at JNI

* fullTranscribe() test passes

* tested on Ubuntu 20

* link to Java bindings
2023-05-20 18:25:02 +03:00
56a87ba45d whisper : fix hebrew language code (#935) 2023-05-20 18:17:54 +03:00
95b02d76b0 coreml : add support of large-v1 model (#926) 2023-05-15 18:36:06 +03:00
a5defbc1b9 release : v1.4.2 2023-05-14 19:06:45 +03:00
aaf0d41c7c ggml : add AVX dot products 2023-05-14 18:56:46 +03:00
0cb820e0f9 talk-llama : fix build + sync latest llama.cpp 2023-05-14 18:46:42 +03:00
16564f554f readme : improve Core ML model conversion guidance (#915) 2023-05-14 18:11:08 +03:00
fd01209d09 coreml : support quantized model files 2023-05-14 18:09:44 +03:00
e693074aa6 ggml : sync latest ggml
- New Q4 and Q5 formats
- Various improvements
2023-05-14 18:04:23 +03:00
d652cf12ec main : fix help for --no-timestamps arg (#908) 2023-05-14 17:54:57 +03:00
2b6a074305 extra : update ggml sync script 2023-05-14 10:01:52 +03:00
5300117471 whisper.objc : enable Core ML in example & fix segmentation fault (#910)
* coreml : update endcoder header import path

* coreml : force objc_arc in whisper-encoder.mm

* whisper.objc : create coreml/ group link

* whisper.objc : add coreml model link

* whisper.objc : update readme

* coreml : use -fobjc-arc for coreml/whisper-encoder.mm

* ci: create dummy .mlmodelc for pass ios build

* whisper.objc : update readme

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-05-14 09:47:02 +03:00
70af52a316 coreml : fix seg fault, double free (#919, #917, #899) 2023-05-14 09:42:19 +03:00
1d17cd5bb3 coreml : fix memory leak (#899) 2023-05-09 18:38:12 +03:00
bf2449dfae cmake : fix define used for COREML_ALLOW_FALLBACK (#893) 2023-05-08 21:08:09 +03:00
4e4d00c67a talk-llama : only copy used KV cache in get / set state (#890)
---------

Co-authored-by: ejones <evan.q.jones@gmail.com>
2023-05-08 20:59:21 +03:00
9931d66400 readme : add instructions on converting to GGML + "--no-config" to wget (#874) 2023-05-08 20:58:36 +03:00
1a548c048e cmake : fix options disabling AVX and AVX2 flags (#885) 2023-05-08 20:45:53 +03:00
14bee39b29 cmake : add options to disable CPU flags (#860) 2023-05-04 19:31:04 +03:00
d458fcbc15 ci : add cuBLAS build workflow and fix error causing lines in CMakeLists (#867)
* Add windows build with cuBLAS

* Remove error causing lines for cuBLAS on Windows
2023-05-03 23:47:37 +03:00
919e58b96a readme : partial OpenCL GPU support via CLBlast (#863)
* ggml : CLBlast support as in llama.cpp

Building with CLBlast speeds up whisper.cpp ~2x on low end / older AMD APUs (CPU with integrated GPU) such as the A9.

Usage:
WHISPER_CLBLAST=1 make

* CMake/Makefile : CLBlast support as in llama.cpp

Building with CLBlast speeds up whisper.cpp ~2x on low end / older AMD APUs (CPU with integrated GPU) such as the A9.

Usage:
```
Makefile:
cd whisper.cpp
WHISPER_CLBLAST=1 make

CMake:
cd whisper.cpp ; mkdir build ; cd build
cmake -DWHISPER_CLBLAST=ON  ..
make
```

* Update README.md

Added OpenCL Build Instructions

* Instruction: Partial OpenCL GPU support via CLBlast

Added build instructions and examples for Make and CMake to support OpenCL enabled GPUs.
2023-05-03 19:24:43 +03:00
05bef0f0e9 build : CLBlast support as in llama.cpp (#862)
* ggml : CLBlast support as in llama.cpp

Building with CLBlast speeds up whisper.cpp ~2x on low end / older AMD APUs (CPU with integrated GPU) such as the A9.

Usage:
WHISPER_CLBLAST=1 make

* CMake/Makefile : CLBlast support as in llama.cpp

Building with CLBlast speeds up whisper.cpp ~2x on low end / older AMD APUs (CPU with integrated GPU) such as the A9.

Usage:
```
Makefile:
cd whisper.cpp
WHISPER_CLBLAST=1 make

CMake:
cd whisper.cpp ; mkdir build ; cd build
cmake -DWHISPER_CLBLAST=ON  ..
make
```
2023-05-02 22:50:32 +03:00
5974c8facd ggml : fix 32-bit ARM build + quantization 2023-05-02 21:52:26 +03:00
0bcb64b184 ggml : sync ggml (clBLAST + tensor names) 2023-05-02 21:24:18 +03:00
0bf680fea2 talk-llama : fix session prompt load (#854) 2023-05-02 20:05:27 +03:00
b806420873 whisper : add detect-language mode (#853)
* add detectlanguage flag

* renaming and help

* no idea why that last one didn't commit

* run language detection if dl is set

* help message fix

* various fixes

* fix quitting

* fix language being english on print
2023-05-02 19:51:52 +03:00
be5911a9f3 talk-llama : add --session support (#845)
* feat: adding session support

* readme: adding --session info in examples/talk-llama

* llama: adding session fixes

* readme: updating session doc

* talk-llama: update the value of need_to_save_session to true in order to save the session in the subsequent interaction

* talk-llama: adding missing function which updates session_tokens
2023-05-01 20:18:10 +03:00
d375d73b2e bench : improve benchmarks 2023-05-01 14:44:39 +03:00
7765770f89 whisper : add memory sizes for Q8_0 (close #846) 2023-05-01 10:03:56 +03:00
872a85ae94 whisper.wasm : fix typo in readme (#832) 2023-05-01 09:28:05 +03:00
9c61f5f585 release : v1.4.1 2023-04-30 22:57:42 +03:00
c94c469592 whisper : fix quantize bug (#842)
* whisper : debug

* whisper : fix bug during quantization
2023-04-30 22:50:04 +03:00
feac80dd3f ggml : fix UB (int << 31) 2023-04-30 22:27:30 +03:00
fa8dbdc888 release : v1.4.0 2023-04-30 19:23:37 +03:00
4a7d49af95 examples : fix + refactor Levenshtein distance 2023-04-30 19:12:49 +03:00
794b162a46 whisper : add integer quantization support (#540)
* whisper : add integer quantization support

* examples : add common-ggml + prepare to add "quantize" tool

* whisper : quantization tool ready

* whisper : fix F32 support

* whisper : try to fix shared lib linkage

* wasm : update quantized models to Q5

* bench.wasm : remove "medium" button

* bench.wasm : fix custom model button

* ggml : add Q5_0 and Q5_1 WASM SIMD

* wasm : add quantized models to all WASM examples

* wasm : bump DB version number to 2

* talk-llama : update example to latest llama.cpp

* node : increase test timeout to 10s

* readme : add information for model quantization

* wasm : add links to other examples
2023-04-30 18:51:57 +03:00
5fd1bdd7fc whisper : add GPU support via cuBLAS (#834)
* make : add WHISPER_CUBLAS

* make : fix CUBLAS build

* whisper : disable Flash Attention + adjust memory buffers

* whisper : remove old commented code

* readme : add cuBLAS instructions

* cmake : add WHISPER_CUBLAS option

* gitignore : ignore build-cublas
2023-04-30 12:14:33 +03:00
0ccd6746c9 ggml : fix WASM build 2023-04-29 21:37:23 +03:00
d9b550c0a1 ggml : fix 32-bit ARM NEON (#836)
* ggml : add support for 32-bit ARM

* ggml : fix

* ggml : fix
2023-04-29 21:33:33 +03:00
e9b091c92a ggml : use vzip instead of vuzp for consistency 2023-04-29 21:14:09 +03:00
1f30b99208 ggml : fix WASM build 2023-04-29 20:21:25 +03:00
05c3ea3bc8 ggml : sync with ggml repo (warning fixes + asserts) 2023-04-29 19:33:28 +03:00
6108d3cc58 whisper : use correct seek_end when offset is used (#833)
Whenever an `offset_ms` is provided, the value of `seek_end` is
calculated incorrectly. This causes Whisper to keep transcribing
after the end of the file.

The current behavior looks like
```
[00:34:40.000 --> 00:34:47.000]   This is an example audio file.
[00:34:47.000 --> 00:34:49.000]   The text has been redacted
[00:34:49.000 --> 00:34:51.000]   This is the end of the audio.
[00:34:51.000 --> 00:34:52.000]   ***
[00:34:52.000 --> 00:34:53.000]   ***
[00:34:53.000 --> 00:34:54.000]   ***
[00:34:55.000 --> 00:34:56.000]   ***
...
```

The expected behavior should be
```
[00:34:40.000 --> 00:34:47.000]   This is an example audio file.
[00:34:47.000 --> 00:34:49.000]   The text has been redacted
[00:34:49.000 --> 00:34:51.000]   This is the end of the audio.
- end of program -
```

This commit changes the calculation of the `seek_end` variable to
only add `seek_start` if a custom `duration_ms` is provided.
Otherwise, it defaults to the end of the file.

Signed-off-by: Thijs Raymakers <thijs@raymakers.nl>
2023-04-29 18:55:37 +03:00
bab97c83d0 tests : add "threads" to run-tests.sh 2023-04-29 12:32:28 +03:00
3eaeb030ff extra : add sync-ggml.sh script 2023-04-29 12:32:28 +03:00
acec73ab6e ggml : sync latest ggml + llama.cpp updates (quantization) 2023-04-29 12:32:28 +03:00
5cc17418c7 whisper.android : add some tips (#816) 2023-04-29 11:00:20 +03:00
3efb81dec6 build : add WHISPER_COREML_ALLOW_FALLBACK to make / CMake (#812) 2023-04-29 10:55:24 +03:00
94a7cd2a07 whisper : allow non-CoreML fallback when Core ML cannot be loaded (#812)
if the Core ML model cannot be loaded, continue without Core ML instead of
returning. This allows a single build to transcribe using Core ML models
where available, and regular models when not.
2023-04-29 10:49:02 +03:00
3e82ff4747 whisper : fix bug from previous commit 2023-04-29 10:42:14 +03:00
b5bd2f43c5 whisper : avoid designated initializers 2023-04-29 10:36:50 +03:00
94aa56f19e minor : improve C++ and Python style (#768)
* use some STL functions

* use self.field than setattr, use pathlib.Path

* recover some format

* const some iter

* Keep the original

* 2 space
2023-04-29 10:06:25 +03:00
4d89ee2e59 readme : add logo 2023-04-28 22:41:29 +03:00
70567eff23 main : escape quotes in csv output (#815) 2023-04-23 19:01:59 +03:00
02ec83c5d5 stream : flush upon finishing inference (#811) 2023-04-23 17:00:30 +03:00
2bd4b8d577 examples : add missing #include <cstdint> (#798)
common.cpp uses uint8_t and uint64_t, which are defined in <cstdint>.
2023-04-23 16:52:52 +03:00
eecf2c3d41 main : update escape_double_quotes() function (#776)
Updated the escape_double_quotes() function such that the function now escapes both double quotes and backslashes in the input string.

Changes Made:

- Renamed the function to escape_quotes_and_backslashes

- Modified the condition in the first loop to increment the value of 'escaped_length' for both double quotes and backslashes.

- Modified the condition in second loop to add a backslash before the current character if it is a double quote or a backslash.

Resolves: #769
2023-04-23 16:47:30 +03:00
c23588cc4b release : v1.3.0 2023-04-15 17:30:44 +03:00
5108b30e6d whisper : pad audio instead of spectrogram (#579)
Also, fallback only if more temperatures are available and if we are
at least 3 seconds before the end of the audio
2023-04-15 17:19:19 +03:00
f19e23fbd1 whisper : restore decoder temperature fallbacks
I disabled this because there were many complaints about slow decoding.
The current implementation does not allow batching the decoders when
using the "best of" or "beam size" parameters, so the decoding time is
proportional to the number of decoders, which is obviously not great.

However, now there are even more complaints about wrong decodings and
repetition.

So, making a compromise by re-enabling the fallbacks, but defaulting to
just 2 "best of" / "beam size" decoders. Also, the temperature step is
increased from 0.2 to 0.4 - i.e. from maximum of 5 fallbacks to maximum
of 2.

Also, the stream example now has fallbacks enabled by default.

close #471 #477 #508 #612 #719 #731
2023-04-15 16:12:55 +03:00
ea1f8a50d4 ggml, ci : fix build on whisper.android (ARM_NEON) + add CI (#764)
* ggml : fix undefined symbol by remove inline handle

* ggml : make own ggml_aligned_malloc function

* ci: add ios/android build
2023-04-15 14:21:58 +03:00
3dead611bb whisper : slightly faster Log Mel computation + n-1 FFT threads (#568) 2023-04-15 14:18:46 +03:00
355da83690 readme : fix link 2023-04-15 13:30:36 +03:00
3e5c49e59a readme : add usage instructions for Core ML 2023-04-15 13:30:07 +03:00
5e47e223bd whisper : add Core ML support (#566)
* coreml : use Core ML encoder inference

* coreml : simlpify whisper_encode + log messages

* whisper : resolve rebase conflicts

* coreml : add scripts for CoreML model generation

* bench-all : recognize COREML flag
2023-04-15 13:21:27 +03:00
794ff3074a whisper : do not launch log_mel threads when n_thread is 1 (#763) 2023-04-14 22:35:34 +03:00
7e2afa4384 whisper : fix the bug related to word splitting errors in the "tokenize" function. (#760)
Co-authored-by: AfryMask <afrymask@gmail.com>
2023-04-14 20:35:03 +03:00
1c5edc3cb3 readme : add SwiftWhisper to listed bindings (#755) 2023-04-14 20:24:00 +03:00
34b772727d gitignore : add .test 2023-04-14 20:13:47 +03:00
2c856fb9e5 whisper : fix potential memory leaks (#740)
* fix potential memory leak if whisper_init_state failed

* fix potential memory leak if gpt2_init failed
2023-04-14 20:05:56 +03:00
7727a40dc9 license : update year (#739) 2023-04-14 20:04:42 +03:00
b5639ed313 whisper : fix typos in whisper.h (#737)
Fixed a couple of typos (in comments, so nothing major). Keep up the great work 😄
2023-04-14 20:03:16 +03:00
2c4ac2627d stream : support language auto-detect (#501)
#445  fix Language auto-detect "auto" flag does not work using the stream tool
2023-04-14 20:02:18 +03:00
674a8e579b readme : add unity bindings (#733) 2023-04-14 19:59:44 +03:00
001083a769 talk, talk-llama : add basic example script for eleven-labs tts (#728) 2023-04-14 19:53:58 +03:00
62b51c3070 models : change convert-pt-to-ggml to use .tiktoken tokenizer files (#725) 2023-04-14 19:50:39 +03:00
61128870b8 cmake : add msvc compiler args /utf-8 fix error C3688 (#721)
* force msvc compiler use utf-8 encode

* only enable on msvc
2023-04-14 19:36:38 +03:00
78548dc03f talk-llama : correct default speak.sh path (#720)
There is `speak.sh` file in `./examples/talk-llama` as described in README.
However `./examples/talk/speak.sh` is used in `talk-llama.cpp`, this commit corrects that.
2023-04-14 19:36:09 +03:00
66110dafcc main : add lrc output support (#718)
* add lrc output support.

* fix wrong comment
2023-04-14 19:35:33 +03:00
Sam
b73a4638ac readme : make the quick start instructions clearer. (#716)
Users wanting to make use of this implementation of the whisper model with no prior knowledge of C/C++ may download the Whisper model but fail to use of the "make" command as specified given that they forgot or didn't know they needed to clone the repository first. Hope this modification clears things up.
2023-04-14 19:33:06 +03:00
5f16420333 make : disable avx in case f16c is not available (#706)
Why:

* ggml.c does not support AVX without F16C
2023-04-14 19:31:51 +03:00
ccb47e7e10 readme : add shell command example for --print-colors (#710)
The section of the readme file explaining `--print-colors` includes only a screenshot with directories that are inconsistent with other examples. This commit adds an example shell command, consistent with the remaining examples.
2023-04-14 19:25:23 +03:00
677ad754a0 ggml : sync latest ggml 2023-04-14 19:20:39 +03:00
514cd04452 whisper : fix bug in prompt processing (close #705)
Was dereferencing a dangling pointer
2023-04-14 19:17:07 +03:00
6704a81255 go : exposed various parts to the Go Interface (#697) 2023-04-14 18:52:10 +03:00
463e46338c ggml : fix q4_1 dot product types (#759)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-04-14 13:34:20 +03:00
2f889132c6 ggml : sync latest changes from ggml and llama.cpp 2023-04-13 18:53:44 +03:00
ebef1e8620 ggml : fix WASM build 2023-04-10 23:18:29 +03:00
114df388fe talk-llama : increase context to 2048 2023-04-10 23:09:15 +03:00
ea36831459 talk-llama : update to latest llama.cpp (improved performance) 2023-04-10 22:59:13 +03:00
69b8503935 ggml : backport llama.cpp updates (close #709)
- About x2 overall performance improvement on Apple Silicon
- Results should now be the same for different number of threads (not
  tested)
2023-04-10 22:28:54 +03:00
0a2d1210bc whisper : add progress callback (#600) 2023-03-30 20:29:29 +03:00
859ffc994e misc : typo (#688) 2023-03-30 07:51:33 +03:00
5e6e2187a3 talk-llama : fixing usage message for talk-llama (#687)
"-ml" instead of "-mg" for specifying the llama file
2023-03-30 00:10:20 +03:00
a7f1f33715 main : add <cstring> header 2023-03-29 23:59:45 +03:00
86ecfc6333 whisper.addon : fixed test to new async implementation (#686)
* fixed blocking code on node addon

* modify the example to run async

* format

* added logic to see the whisper output

* added logic to see the whisper output

* removed extra function for more clean example

* fixed whisper test to new async implementation
2023-03-29 23:59:17 +03:00
18e6fb0287 models : handle spaces and special characters in shell script paths (#677)
This commit modifies the `get_script_path` function to correctly handle
spaces and special characters in directory paths. The fix involves adding
double quotes around variables and commands where needed to ensure proper
parsing of paths with spaces and special characters.
2023-03-29 23:38:33 +03:00
0f759f125d main : fix typo in JSON output (#648)
* typo in JSON output

* fix double quotes in JSON output
2023-03-29 23:26:39 +03:00
eefed45e37 whisper : add initial_prompt param (#645) 2023-03-29 23:23:23 +03:00
aac1710afb make : 32-bit ARM flags (#486)
* issue #470 - working 32-bit ARM

* Update Makefile

* Update Makefile

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-03-29 23:11:35 +03:00
21c1e6afc5 whisper.swiftui : update README.md (#682)
- Slight tweaks to README for improved comprehension.
2023-03-29 23:04:38 +03:00
a47e812a54 talk-llama : add alpaca support (#668) 2023-03-29 23:01:14 +03:00
42c6855103 whisper : bump "large" scratch buffer even mode (close #671) 2023-03-28 10:50:49 +03:00
0be9cd3497 whisper : increase scratch buffers after recent change (#671)
Should fix the error:

ggml_new_tensor_impl: not enough space in the scratch memory
2023-03-28 10:36:16 +03:00
e5c197d8aa talk-llama : add discussion link 2023-03-28 10:11:34 +03:00
7cd1d3bc34 talk-llama : try to fix windows build .. 2023-03-27 22:40:59 +03:00
82637b8e9f readme : add talk-llama example to the table 2023-03-27 21:02:35 +03:00
4a0deb8b1e talk-llama : add new example + sync ggml from llama.cpp (#664)
* talk-llama : talk with LLaMA AI

* talk.llama : disable EOS token

* talk-llama : add README instructions

* ggml : fix build in debug
2023-03-27 21:00:32 +03:00
142 changed files with 33869 additions and 4349 deletions

View File

@ -125,8 +125,10 @@ jobs:
include:
- arch: Win32
s2arc: x86
jnaPath: win32-x86
- arch: x64
s2arc: x64
jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
@ -159,6 +161,12 @@ jobs:
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload dll
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
@ -235,6 +243,61 @@ jobs:
with:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
windows-cublas:
runs-on: windows-latest
strategy:
matrix:
build: [Release]
arch: [x64]
cublas: [ON]
sdl2: [ON]
include:
- arch: x64
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.10
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=1
- name: Build
run: |
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-cublas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
@ -265,3 +328,85 @@ jobs:
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ios:
runs-on: macos-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Configure
run: |
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
- name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
android:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v1
- name: Install Java
uses: actions/setup-java@v3
with:
distribution: zulu
java-version: 17
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: Build
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
java:
needs: [ 'windows' ]
runs-on: windows-latest
steps:
- uses: actions/checkout@v1
- name: Install Java
uses: actions/setup-java@v1
with:
java-version: 17
- name: Download Windows lib
uses: actions/download-artifact@v3
with:
name: win32-x86-64_whisper.dll
path: bindings/java/build/generated/resources/main/win32-x86-64
- name: Build
run: |
models\download-ggml-model.cmd tiny.en
cd bindings/java
chmod +x ./gradlew
./gradlew build
- name: Upload jar
uses: actions/upload-artifact@v3
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
# - name: Publish package
# if: ${{ github.ref == 'refs/heads/master' }}
# uses: gradle/gradle-build-action@v2
# with:
# arguments: publish
# env:
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}

12
.gitignore vendored
View File

@ -1,6 +1,8 @@
*.o
*.a
.cache/
.coreml/
.test/
.vs/
.vscode/
.DS_Store
@ -10,6 +12,7 @@ build-em/
build-debug/
build-release/
build-static/
build-cublas/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
@ -18,7 +21,9 @@ build-sanitize-thread/
/stream
/command
/talk
/talk-llama
/bench
/quantize
arm_neon.h
sync.sh
@ -32,3 +37,10 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
extra/bench-gg.txt
models/*.mlmodel
models/*.mlmodelc
models/*.mlpackage
bindings/java/.gradle/
bindings/java/.idea/
.idea/

View File

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.2.1)
project(whisper.cpp VERSION 1.4.2)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@ -35,30 +35,40 @@ endif()
# options
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
option(WHISPER_NO_F16C "whisper: disable F16c" OFF)
option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_BLAS "whisper: use BLAS libraries" OFF)
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF)
option(WHISPER_PERF "whisper: enable perf timings" OFF)
# sanitizers
@ -86,31 +96,106 @@ endif()
find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate)
if (ACCELERATE_FRAMEWORK)
message(STATUS "Accelerate framework found")
# on APPLE
if (APPLE)
# include Accelerate framework
if (NOT WHISPER_NO_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
if (ACCELERATE_FRAMEWORK)
message(STATUS "Accelerate framework found")
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()
endif()
if (WHISPER_COREML)
find_library(FOUNDATION_FRAMEWORK Foundation)
find_library(COREML_FRAMEWORK CoreML)
if (COREML_FRAMEWORK)
message(STATUS "CoreML framework found")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
else()
message(WARNING "CoreML framework not found")
endif()
if (WHISPER_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_COREML_ALLOW_FALLBACK)
endif()
endif()
endif()
if (WHISPER_SUPPORT_OPENBLAS)
find_library(OPENBLAS_LIB
NAMES openblas libopenblas
)
if (OPENBLAS_LIB)
message(STATUS "OpenBLAS found")
if (WHISPER_OPENBLAS)
set(WHISPER_BLAS_VENDOR "OpenBLAS")
set(WHISPER_BLAS ON)
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${OPENBLAS_LIB})
if (WHISPER_BLAS)
set(BLA_STATIC 1)
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
# set(BLA_PREFER_PKGCONFIG 1)
set(BLA_SIZEOF_INTEGER 8)
find_package(BLAS)
if(BLAS_FOUND)
message(STATUS "BLAS compatible library found")
message(STATUS "Libraries ${BLAS_LIBRARIES}")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories(${BLAS_INCLUDE_DIRS})
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else()
message(WARNING "OpenBLAS not found")
message(WARNING "BLAS library was not found")
endif()
endif ()
if (WHISPER_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
if (WHISPER_STATIC)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
if (WHISPER_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
else()
message(WARNING "CLBlast not found")
endif()
endif()
if( WHISPER_OPENVINO )
find_package(OpenVINO REQUIRED COMPONENTS Runtime)
endif()
# compiler flags
@ -155,9 +240,17 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
else()
message(STATUS "x86 detected")
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
if(NOT WHISPER_NO_AVX2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
endif()
endif()
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -183,6 +276,51 @@ if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif()
#
# whisper.coreml - Core ML support
#
if (WHISPER_COREML)
set(TARGET whisper.coreml)
add_library(${TARGET}
coreml/whisper-encoder.h
coreml/whisper-encoder.mm
coreml/whisper-encoder-impl.h
coreml/whisper-encoder-impl.m
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
set_target_properties(${TARGET} PROPERTIES
COMPILE_FLAGS "-fobjc-arc"
)
endif()
if (WHISPER_OPENVINO)
set(TARGET whisper.openvino)
add_library(${TARGET} OBJECT
openvino/whisper-openvino-encoder.h
openvino/whisper-openvino-encoder.cpp
)
target_include_directories(${TARGET} PUBLIC
.
)
set_property(TARGET ${TARGET} PROPERTY POSITION_INDEPENDENT_CODE ON)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_OPENVINO)
target_link_libraries(${TARGET} PRIVATE openvino::runtime)
endif()
#
# whisper - this is the main library of the project
#
@ -192,6 +330,8 @@ set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
${GGML_CUDA_SOURCES}
${GGML_OPENCL_SOURCES}
whisper.h
whisper.cpp
)
@ -202,6 +342,14 @@ target_include_directories(${TARGET} PUBLIC
.
)
if (WHISPER_COREML)
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
endif()
if (WHISPER_OPENVINO)
target_link_libraries(${TARGET} PRIVATE whisper.openvino)
endif()
if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
@ -217,7 +365,19 @@ if (BUILD_SHARED_LIBS)
target_compile_definitions(${TARGET} PUBLIC
WHISPER_SHARED
GGML_SHARED
)
target_compile_definitions(${TARGET} PRIVATE
WHISPER_BUILD
GGML_BUILD
)
endif()
if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif()
if (EMSCRIPTEN)

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2022 Georgi Gerganov
Copyright (c) 2023 Georgi Gerganov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

149
Makefile
View File

@ -1,3 +1,5 @@
default: main bench quantize
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
@ -36,10 +38,17 @@ LDFLAGS =
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
ifneq ($(wildcard /usr/include/musl/*),)
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
endif
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
# and on macOS its availability depends on enabling Darwin extensions
ifeq ($(UNAME_S),Darwin)
CFLAGS += -D_DARWIN_C_SOURCE
CXXFLAGS += -D_DARWIN_C_SOURCE
endif
# OS specific
# TODO: support Windows
ifeq ($(UNAME_S),Linux)
@ -77,10 +86,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
CFLAGS += -mavx2
endif
else ifeq ($(UNAME_S),Linux)
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
@ -92,16 +97,17 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
@ -113,6 +119,11 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
else
CFLAGS += -mfma -mf16c -mavx -mavx2
@ -121,6 +132,7 @@ endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifneq ($(filter ppc64%,$(UNAME_M)),)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M)))
@ -131,6 +143,7 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
endif
endif
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
@ -138,29 +151,71 @@ ifndef WHISPER_NO_ACCELERATE
LDFLAGS += -framework Accelerate
endif
endif
ifdef WHISPER_COREML
CXXFLAGS += -DWHISPER_USE_COREML
LDFLAGS += -framework Foundation -framework CoreML
ifdef WHISPER_COREML_ALLOW_FALLBACK
CXXFLAGS += -DWHISPER_COREML_ALLOW_FALLBACK
endif
endif
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas
endif
ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=any
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast -lOpenCL
WHISPER_OBJ += ggml-opencl.o
ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
$(CC) $(CFLAGS) -c $< -o $@
endif
ifdef WHISPER_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
CFLAGS += -mcpu=native
CFLAGS += -mcpu=native
CXXFLAGS += -mcpu=native
endif
ifneq ($(filter armv6%,$(UNAME_M)),)
# Raspberry Pi 1, 2, 3
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
# 32-bit Raspberry Pi 1, 2, 3
CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access
endif
ifneq ($(filter armv7%,$(UNAME_M)),)
# Raspberry Pi 4
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
# 32-bit ARM, for example on Armbian or possibly raspbian
#CFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
#CXXFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
# 64-bit ARM on 32-bit OS, use these (TODO: auto-detect 64-bit)
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
endif
ifneq ($(filter armv8%,$(UNAME_M)),)
# Raspberry Pi 4
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
endif
#
@ -178,26 +233,36 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
#
# Build library
#
ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@
libwhisper.a: ggml.o whisper.o
$(AR) rcs libwhisper.a ggml.o whisper.o
ifndef WHISPER_COREML
WHISPER_OBJ += whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder.mm -o whisper-encoder.o
libwhisper.so: ggml.o whisper.o
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
endif
libwhisper.a: ggml.o $(WHISPER_OBJ)
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so
#
# Examples
@ -205,24 +270,30 @@ clean:
CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
SRC_COMMON_SDL = examples/common-sdl.cpp
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
./main -h
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
#
# Audio samples
@ -237,12 +308,16 @@ samples:
@wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
@wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
@wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
@wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
@echo "Converting to 16-bit WAV ..."
@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
@ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
@ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
@rm samples/*.ogg
@ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
@rm samples/mm1.wav
@ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
@rm samples/a13.mp3
#
# Models

185
README.md
View File

@ -1,21 +1,27 @@
# whisper.cpp
![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime
- Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
Supported platforms:
@ -23,6 +29,7 @@ Supported platforms:
- [x] Mac OS (Intel and Arm)
- [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android)
- [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
@ -58,12 +65,16 @@ the Accelerate framework utilizes the special-purpose AMX coprocessor available
## Quick start
First, download one of the Whisper models converted in [ggml format](models). For example:
First clone the repository.
Then, download one of the Whisper models converted in [ggml format](models). For example:
```bash
bash ./models/download-ggml-model.sh base.en
```
If you wish to convert the Whisper models to ggml format yourself, instructions are in [models/README.md](models/README.md).
Now build the [main](examples/main) example and transcribe an audio file like this:
```bash
@ -104,6 +115,7 @@ options:
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
@ -223,10 +235,134 @@ make large
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
## Quantization
`whisper.cpp` supports integer quantization of the Whisper `ggml` models.
Quantized models require less memory and disk space and depending on the hardware can be processed more efficiently.
Here are the steps for creating and using a quantized model:
```bash
# quantize a model with Q5_0 method
make quantize
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
# run the examples as usual, specifying the quantized model file
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
```
## Core ML support
On Apple Silicon devices, the Encoder inference can be executed on the Apple Neural Engine (ANE) via Core ML. This can result in significant
speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions for generating a Core ML model and using it with `whisper.cpp`:
- Install Python dependencies needed for the creation of the Core ML model:
```bash
pip install ane_transformers
pip install openai-whisper
pip install coremltools
```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.10 is recommended.
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use:
```bash
./models/generate-coreml-model.sh base.en
```
This will generate the folder `models/ggml-base.en-encoder.mlmodelc`
- Build `whisper.cpp` with Core ML support:
```bash
# using Makefile
make clean
WHISPER_COREML=1 make -j
# using CMake
cd build
cmake -DWHISPER_COREML=1 ..
```
- Run the examples as usual. For example:
```bash
./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
whisper_init_state: loading Core ML model from 'models/ggml-base.en-encoder.mlmodelc'
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
...
```
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## NVIDIA GPU support via cuBLAS
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
Now build `whisper.cpp` with cuBLAS support:
```
make clean
WHISPER_CUBLAS=1 make -j
```
## OpenCL GPU support via CLBlast
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APUs or low end devices for up to ~2x speedup.
First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast
Now build `whisper.cpp` with CLBlast support:
```
Makefile:
cd whisper.cpp
make clean
WHISPER_CLBLAST=1 make -j
CMake:
cd whisper.cpp ; mkdir build ; cd build
cmake -DWHISPER_CLBLAST=ON ..
make clean
make -j
cp bin/* ../
```
Run all the examples as usual.
## BLAS CPU support via OpenBLAS
Encoder processing can be accelerated on the CPU via OpenBLAS.
First, make sure you have installed `openblas`: https://www.openblas.net/
Now build `whisper.cpp` with OpenBLAS support:
```
make clean
WHISPER_OPENBLAS=1 make -j
```
## Limitations
- Inference only
- No GPU support (yet)
## Another example
@ -313,7 +449,7 @@ whisper_print_timings: total time = 32733.52 ms
## Real-time audio input example
This is a naive example of performing real-time inference on audio from your microphone.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```java
@ -328,6 +464,10 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
to highlight words with high or low confidence:
```java
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
```
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
## Controlling the length of the generated text segments (experimental)
@ -354,7 +494,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.020 --> 00:00:11.000] country.
```
## Word-level timestamp
## Word-level timestamp (experimental)
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
@ -367,7 +507,7 @@ system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320]
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
@ -395,6 +535,32 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.510 --> 00:00:11.000] .
```
## Speaker segmentation via tinydiarize (experimental)
More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
Sample usage:
```py
# download a tinydiarize compatible model
./models/download-ggml-model.sh small.en-tdrz
# run as usual, adding the "-tdrz" command-line argument
./main -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
...
main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, tdrz = 1, timestamps = 1 ...
...
[00:00:00.000 --> 00:00:03.800] Okay Houston, we've had a problem here. [SPEAKER_TURN]
[00:00:03.800 --> 00:00:06.200] This is Houston. Say again please. [SPEAKER_TURN]
[00:00:06.200 --> 00:00:08.260] Uh Houston we've had a problem.
[00:00:08.260 --> 00:00:11.320] We've had a main beam up on a volt. [SPEAKER_TURN]
[00:00:11.320 --> 00:00:13.820] Roger main beam interval. [SPEAKER_TURN]
[00:00:13.820 --> 00:00:15.100] Uh uh [SPEAKER_TURN]
[00:00:15.100 --> 00:00:18.020] So okay stand, by thirteen we're looking at it. [SPEAKER_TURN]
[00:00:18.020 --> 00:00:25.740] Okay uh right now uh Houston the uh voltage is uh is looking good um.
[00:00:27.620 --> 00:00:29.940] And we had a a pretty large bank or so.
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
@ -478,8 +644,11 @@ in [models](models).
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Java:
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
@ -487,6 +656,7 @@ in [models](models).
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
## Examples
@ -500,6 +670,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |

View File

@ -32,7 +32,7 @@ mkdir:
modtidy:
@go mod tidy
clean:
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go clean

View File

@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil); err != nil {
if err := context.Process(samples, nil, nil); err != nil {
return err
}
@ -71,7 +71,7 @@ The examples are placed in the `build` directory. Once built, you can download a
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings

View File

@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb); err != nil {
if err := context.Process(data, cb, nil); err != nil {
return err
}

View File

@ -105,6 +105,10 @@ func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}
func (p *Params) SetTokenTimestamps(b bool) {
p.token_timestamps = toBool(b)
}
// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)

View File

@ -93,7 +93,7 @@ func (context *context) SetOffset(v time.Duration) {
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
context.params.SetDuration(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
@ -111,6 +111,11 @@ func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set token timestamps flag
func (context *context) SetTokenTimestamps(b bool) {
context.params.SetTokenTimestamps(b)
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
@ -147,12 +152,16 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
func (context *context) Process(
data []float32,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
// If the callback is defined then we force on single_segment mode
if cb != nil {
if callNewSegment != nil {
context.params.SetSingleSegment(true)
}
@ -160,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if cb != nil {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if cb != nil {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
callNewSegment(toSegment(context.model.ctx, i))
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
return err
}
@ -280,10 +293,14 @@ func toSegment(ctx *whisper.Context, n int) Segment {
func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ {
data := ctx.Whisper_full_get_token_data(n, i)
result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i),
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: ctx.Whisper_full_get_token_text(n, i),
P: ctx.Whisper_full_get_token_p(n, i),
Start: time.Duration(data.T0()) * time.Millisecond * 10,
End: time.Duration(data.T1()) * time.Millisecond * 10,
}
}
return result

View File

@ -12,6 +12,10 @@ import (
// time. It is called during the Process function
type SegmentCallback func(Segment)
// ProgressCallback is the callback function for reporting progress during
// processing. It is called during the Process function
type ProgressCallback func(int)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
@ -41,12 +45,13 @@ type Context interface {
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback) error
Process([]float32, SegmentCallback, ProgressCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
@ -85,7 +90,8 @@ type Segment struct {
// Token is a text or special token
type Token struct {
Id int
Text string
P float32
Id int
Text string
P float32
Start, End time.Duration
}

View File

@ -15,6 +15,7 @@ import (
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern void callProgress(void* user_data, int progress);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
}
}
// Progress callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callProgress(user_data, progress);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
params.progress_callback = whisper_progress_cb;
params.progress_callback_user_data = (void*)(ctx);
return params;
}
*/
@ -258,13 +270,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
}
// Task tokens
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
func (ctx *Context) Whisper_token_translate() Token {
return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
}
// Task tokens
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
func (ctx *Context) Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
}
// Performance information
@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
func (ctx *Context) Whisper_full(
params Params,
samples []float32,
encoderBeginCallback func() bool,
newSegmentCallback func(int),
progressCallback func(int),
) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
registerProgressCallback(ctx, progressCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
defer registerProgressCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
@ -318,6 +338,18 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc
}
}
// Return the id of the autodetected language, returns -1 if not found
// Added to whisper.cpp in
// https://github.com/ggerganov/whisper.cpp/commit/a1c1583cc7cd8b75222857afc936f0638c5683d6
//
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_full_lang_id() int {
return int(C.whisper_full_lang_id((*C.struct_whisper_context)(ctx)))
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
@ -356,7 +388,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
@ -370,6 +402,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbProgress = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
@ -381,6 +414,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
}
}
func registerProgressCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbProgress, unsafe.Pointer(ctx))
} else {
cbProgress[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
@ -396,6 +437,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
}
}
//export callProgress
func callProgress(user_data unsafe.Pointer, progress C.int) {
if fn, ok := cbProgress[user_data]; ok {
fn(int(progress))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
@ -407,3 +455,15 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
}
return true
}
func (t TokenData) T0() int64 {
return int64(t.t0)
}
func (t TokenData) T1() int64 {
return int64(t.t1)
}
func (t TokenData) Id() Token {
return Token(t.id)
}

View File

@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
err = ctx.Whisper_full(params, data, nil, nil, nil)
assert.NoError(err)
// Print out tokens

124
bindings/java/.idea/uiDesigner.xml generated Normal file
View File

@ -0,0 +1,124 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Palette2">
<group name="Swing">
<item class="com.intellij.uiDesigner.HSpacer" tooltip-text="Horizontal Spacer" icon="/com/intellij/uiDesigner/icons/hspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="1" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="com.intellij.uiDesigner.VSpacer" tooltip-text="Vertical Spacer" icon="/com/intellij/uiDesigner/icons/vspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="1" anchor="0" fill="2" />
</item>
<item class="javax.swing.JPanel" icon="/com/intellij/uiDesigner/icons/panel.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3" />
</item>
<item class="javax.swing.JScrollPane" icon="/com/intellij/uiDesigner/icons/scrollPane.svg" removable="false" auto-create-binding="false" can-attach-label="true">
<default-constraints vsize-policy="7" hsize-policy="7" anchor="0" fill="3" />
</item>
<item class="javax.swing.JButton" icon="/com/intellij/uiDesigner/icons/button.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="0" fill="1" />
<initial-values>
<property name="text" value="Button" />
</initial-values>
</item>
<item class="javax.swing.JRadioButton" icon="/com/intellij/uiDesigner/icons/radioButton.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="RadioButton" />
</initial-values>
</item>
<item class="javax.swing.JCheckBox" icon="/com/intellij/uiDesigner/icons/checkBox.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="CheckBox" />
</initial-values>
</item>
<item class="javax.swing.JLabel" icon="/com/intellij/uiDesigner/icons/label.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="8" fill="0" />
<initial-values>
<property name="text" value="Label" />
</initial-values>
</item>
<item class="javax.swing.JTextField" icon="/com/intellij/uiDesigner/icons/textField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JPasswordField" icon="/com/intellij/uiDesigner/icons/passwordField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JFormattedTextField" icon="/com/intellij/uiDesigner/icons/formattedTextField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JTextArea" icon="/com/intellij/uiDesigner/icons/textArea.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTextPane" icon="/com/intellij/uiDesigner/icons/textPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JEditorPane" icon="/com/intellij/uiDesigner/icons/editorPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JComboBox" icon="/com/intellij/uiDesigner/icons/comboBox.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="2" anchor="8" fill="1" />
</item>
<item class="javax.swing.JTable" icon="/com/intellij/uiDesigner/icons/table.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JList" icon="/com/intellij/uiDesigner/icons/list.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="2" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTree" icon="/com/intellij/uiDesigner/icons/tree.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTabbedPane" icon="/com/intellij/uiDesigner/icons/tabbedPane.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSplitPane" icon="/com/intellij/uiDesigner/icons/splitPane.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSpinner" icon="/com/intellij/uiDesigner/icons/spinner.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSlider" icon="/com/intellij/uiDesigner/icons/slider.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSeparator" icon="/com/intellij/uiDesigner/icons/separator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3" />
</item>
<item class="javax.swing.JProgressBar" icon="/com/intellij/uiDesigner/icons/progressbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="javax.swing.JToolBar" icon="/com/intellij/uiDesigner/icons/toolbar.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1">
<preferred-size width="-1" height="20" />
</default-constraints>
</item>
<item class="javax.swing.JToolBar$Separator" icon="/com/intellij/uiDesigner/icons/toolbarSeparator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="0" fill="1" />
</item>
<item class="javax.swing.JScrollBar" icon="/com/intellij/uiDesigner/icons/scrollbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="0" anchor="0" fill="2" />
</item>
</group>
</component>
</project>

71
bindings/java/README.md Normal file
View File

@ -0,0 +1,71 @@
# Java JNI bindings for Whisper
This package provides Java JNI bindings for whisper.cpp. They have been tested on:
* <strike>Darwin (OS X) 12.6 on x64_64</strike>
* Ubuntu on x86_64
* Windows on x86_64
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
JNA will attempt to load the `whispercpp` shared library from:
- jna.library.path
- jna.platform.library
- ~/Library/Frameworks
- /Library/Frameworks
- /System/Library/Frameworks
- classpath
```java
import io.github.ggerganov.whispercpp.WhisperCpp;
public class Example {
public static void main(String[] args) {
WhisperCpp whisper = new WhisperCpp();
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
long context = whisper.initContext("base.en");
try {
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// custom configuration if required
whisperParams.temperature_inc = 0f;
var samples = readAudio(); // divide each value by 32767.0f
whisper.fullTranscribe(whisperParams, samples);
int segmentCount = whisper.getTextSegmentCount(context);
for (int i = 0; i < segmentCount; i++) {
String text = whisper.getTextSegment(context, i);
System.out.println(segment.getText());
}
} finally {
whisper.freeContext(context);
}
}
}
```
## Building & Testing
In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
./gradlew build
```
You need to have the `whisper` library in your [JNA library path](https://java-native-access.github.io/jna/4.2.1/com/sun/jna/NativeLibrary.html). On Windows the dll is included in the jar and you can update it:
```bash
copy /y ..\..\build\bin\Release\whisper.dll build\generated\resources\main\win32-x86-64\whisper.dll
```
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

112
bindings/java/build.gradle Normal file
View File

@ -0,0 +1,112 @@
plugins {
id 'java'
id 'java-library'
id 'maven-publish'
}
archivesBaseName = 'whispercpp'
group = 'io.github.ggerganov'
version = '1.4.0'
sourceCompatibility = 1.8
targetCompatibility = 1.8
sourceSets {
main {
resources {
srcDirs = ['src/main/resources', 'build/generated/resources/main']
}
}
test {
runtimeClasspath += files('build/generated/resources/main')
}
}
tasks.register('copyLibwhisperDynlib', Copy) {
from '../../build'
include 'libwhisper.dynlib'
into 'build/generated/resources/main/darwin'
}
tasks.register('copyLibwhisperSo', Copy) {
from '../../build'
include 'libwhisper.so'
into 'build/generated/resources/main/linux-x86-64'
}
tasks.register('copyWhisperDll', Copy) {
from '../../build/Release'
include 'whisper.dll'
into 'build/generated/resources/main/windows-x86-64'
}
tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
}
test {
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
}
java {
withSourcesJar()
withJavadocJar()
}
jar {
exclude '**/whisper_java.exp', '**/whisper_java.lib'
}
javadoc {
options.addStringOption('Xdoclint:none', '-quiet')
}
tasks.withType(Test) {
useJUnitPlatform()
}
dependencies {
implementation "net.java.dev.jna:jna:5.13.0"
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"
testImplementation "org.assertj:assertj-core:3.24.2"
}
repositories {
mavenCentral()
}
publishing {
publications {
mavenJava(MavenPublication) {
artifactId = 'whispercpp'
from components.java
pom {
name = 'whispercpp'
description = "Java JNA bindings for OpenAI's Whisper model, implemented in C/C++"
url = 'https://github.com/ggerganov/whisper.cpp'
licenses {
license {
name = 'MIT licence'
url = 'https://raw.githubusercontent.com/ggerganov/whisper.cpp/master/LICENSE'
}
}
developers {
developer {
id = 'ggerganov'
name = 'Georgi Gerganov'
email = 'ggerganov@gmail.com'
}
developer {
id = 'nalbion'
name = 'Nicholas Albion'
email = 'nalbion@yahoo.com'
}
}
scm {
connection = 'scm:git:git://github.com/ggerganov/whisper.cpp.git'
url = 'https://github.com/ggerganov/whisper.cpp'
}
}
}
}
}

View File

@ -0,0 +1,6 @@
org.gradle.jvmargs=-Xms256m -Xmx1024m
system.include.dir=/usr/include
#system.local.include.dir=../../include
system.local.include.dir=./build/generated/sources/headers/java/main
jni.include.dir=/usr/lib/jvm/java-8-openjdk-amd64/include/
jni.lib.dir=/usr/lib/jvm/java-8-openjdk-amd64/lib/

Binary file not shown.

View File

@ -0,0 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

244
bindings/java/gradlew vendored Normal file
View File

@ -0,0 +1,244 @@
#!/bin/sh
#
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
#
# Gradle start up script for POSIX generated by Gradle.
#
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
warn () {
echo "$*"
} >&2
die () {
echo
echo "$*"
echo
exit 1
} >&2
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Collect all arguments for the java command, stacking in reverse order:
# * args from the command line
# * the main class name
# * -classpath
# * -D...appname settings
# * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

92
bindings/java/gradlew.bat vendored Normal file
View File

@ -0,0 +1,92 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

View File

@ -0,0 +1 @@
rootProject.name = "whispercpp"

View File

@ -0,0 +1,39 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import java.util.List;
public class WhisperContext extends Structure {
int t_load_us = 0;
int t_start_us = 0;
/** weight type (FP32 / FP16 / QX) */
GgmlType wtype = GgmlType.GGML_TYPE_F16;
/** intermediate type (FP32 or FP16) */
GgmlType itype = GgmlType.GGML_TYPE_F16;
// WhisperModel model;
public PointerByReference model;
// whisper_vocab vocab;
// whisper_state * state = nullptr;
public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file() */
String path_model;
// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
//
// public static class ByValue extends WhisperContext implements Structure.ByValue {
// }
//
// @Override
// protected List<String> getFieldOrder() {
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
// }
}

View File

@ -0,0 +1,151 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
/**
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
*/
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
if (modelDirPath == null) {
modelDirPath = System.getProperty("user.home") + "/.cache";
}
return new File(modelDirPath, "whisper");
}
/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
if (!modelPath.contains("/") && !modelPath.contains("\\")) {
if (!modelPath.endsWith(".bin")) {
modelPath = "ggml-" + modelPath.replace("-", ".") + ".bin";
}
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}
ctx = lib.whisper_init_from_file(modelPath);
if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - GREEDY
*/
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer;
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
}
WhisperFullParams params = new WhisperFullParams(pointer);
params.read();
return params;
}
@Override
public void close() {
freeContext();
freeParams();
System.out.println("Whisper closed");
}
private void freeContext() {
if (ctx != null) {
lib.whisper_free(ctx);
}
}
private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
}
}
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
int nSegments = lib.whisper_full_n_segments(ctx);
StringBuilder str = new StringBuilder();
for (int i = 0; i < nSegments; i++) {
String text = lib.whisper_full_get_segment_text(ctx, i);
System.out.println("Segment:" + text);
str.append(text);
}
return str.toString().trim();
}
// public int getTextSegmentCount(Pointer ctx) {
// return lib.whisper_full_n_segments(ctx);
// }
// public String getTextSegment(Pointer ctx, int index) {
// return lib.whisper_full_get_segment_text(ctx, index);
// }
public String getSystemInfo() {
return lib.whisper_print_system_info();
}
public int benchMemcpy(int nthread) {
return lib.whisper_bench_memcpy(nthread);
}
public int benchGgmlMulMat(int nthread) {
return lib.whisper_bench_ggml_mul_mat(nthread);
}
}

View File

@ -0,0 +1,376 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library {
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
String whisper_print_system_info();
/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer(Pointer buffer, int buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init(WhisperModelLoader loader);
/**
* Allocate (almost) all memory needed for the model by loading from a file without allocating the state.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_no_state(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer without allocating the state.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer_no_state(Pointer buffer, int buffer_size);
// Pointer whisper_init_from_buffer_no_state(Pointer buffer, long buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader without allocating the state.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_no_state(WhisperModelLoader loader);
/**
* Allocate memory for the Whisper state.
*
* @param ctx Whisper context
* @return Whisper state on success, null on failure
*/
Pointer whisper_init_state(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper context.
*
* @param ctx Whisper context
*/
void whisper_free(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper state.
*
* @param state Whisper state
*/
void whisper_free_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
*
* @param ctx - Pointer to a WhisperContext
* @return 0 on success
*/
int whisper_pcm_to_mel(Pointer ctx, final float[] samples, int n_samples, int n_threads);
/**
* @param ctx Pointer to a WhisperContext
* @param state Pointer to WhisperState
* @param n_samples
* @param n_threads
* @return 0 on success
*/
int whisper_pcm_to_mel_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/**
* This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
* Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
* n_mel must be 80
* @return 0 on success
*/
int whisper_set_mel(Pointer ctx, final float[] data, int n_len, int n_mel);
int whisper_set_mel_with_state(Pointer ctx, Pointer state, final float[] data, int n_len, int n_mel);
/**
* Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
* Offset can be used to specify the offset of the first frame in the spectrogram.
* @return 0 on success
*/
int whisper_encode(Pointer ctx, int offset, int n_threads);
int whisper_encode_with_state(Pointer ctx, Pointer state, int offset, int n_threads);
/**
* Run the Whisper decoder to obtain the logits and probabilities for the next token.
* Make sure to call whisper_encode() first.
* tokens + n_tokens is the provided context for the decoder.
* n_past is the number of tokens to use from previous decoder calls.
* Returns 0 on success
* TODO: add support for multiple decoders
*/
int whisper_decode(Pointer ctx, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* @param ctx
* @param state
* @param tokens Pointer to int tokens
* @param n_tokens
* @param n_past
* @param n_threads
* @return
*/
int whisper_decode_with_state(Pointer ctx, Pointer state, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* Convert the provided text into tokens.
* The tokens pointer must be large enough to hold the resulting tokens.
* Returns the number of tokens on success, no more than n_max_tokens
* Returns -1 on failure
* TODO: not sure if correct
*/
int whisper_tokenize(Pointer ctx, String text, Pointer tokens, int n_max_tokens);
/** Largest language id (i.e. number of available languages - 1) */
int whisper_lang_max_id();
/**
* @return the id of the specified language, returns -1 if not found.
* Examples:
* "de" -> 2
* "german" -> 2
*/
int whisper_lang_id(String lang);
/** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
String whisper_lang_str(int id);
/**
* Use mel data at offset_ms to try and auto-detect the spoken language.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
* Returns the top language id or negative on failure
* If not null, fills the lang_probs array with the probabilities of all languages
* The array must be whisper_lang_max_id() + 1 in size
*
* ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
*/
int whisper_lang_auto_detect(Pointer ctx, int offset_ms, int n_threads, float[] lang_probs);
int whisper_lang_auto_detect_with_state(Pointer ctx, Pointer state, int offset_ms, int n_threads, float[] lang_probs);
int whisper_n_len (Pointer ctx); // mel length
int whisper_n_len_from_state(Pointer state); // mel length
int whisper_n_vocab (Pointer ctx);
int whisper_n_text_ctx (Pointer ctx);
int whisper_n_audio_ctx (Pointer ctx);
int whisper_is_multilingual (Pointer ctx);
int whisper_model_n_vocab (Pointer ctx);
int whisper_model_n_audio_ctx (Pointer ctx);
int whisper_model_n_audio_state(Pointer ctx);
int whisper_model_n_audio_head (Pointer ctx);
int whisper_model_n_audio_layer(Pointer ctx);
int whisper_model_n_text_ctx (Pointer ctx);
int whisper_model_n_text_state (Pointer ctx);
int whisper_model_n_text_head (Pointer ctx);
int whisper_model_n_text_layer (Pointer ctx);
int whisper_model_n_mels (Pointer ctx);
int whisper_model_ftype (Pointer ctx);
int whisper_model_type (Pointer ctx);
/**
* Token logits obtained from the last call to whisper_decode().
* The logits for the last token are stored in the last row
* Rows: n_tokens
* Cols: n_vocab
*/
float[] whisper_get_logits (Pointer ctx);
float[] whisper_get_logits_from_state(Pointer state);
// Token Id -> String. Uses the vocabulary in the provided context
String whisper_token_to_str(Pointer ctx, int token);
String whisper_model_type_readable(Pointer ctx);
// Special tokens
int whisper_token_eot (Pointer ctx);
int whisper_token_sot (Pointer ctx);
int whisper_token_prev(Pointer ctx);
int whisper_token_solm(Pointer ctx);
int whisper_token_not (Pointer ctx);
int whisper_token_beg (Pointer ctx);
int whisper_token_lang(Pointer ctx, int lang_id);
// Task tokens
int whisper_token_translate (Pointer ctx);
int whisper_token_transcribe(Pointer ctx);
// Performance information from the default state.
void whisper_print_timings(Pointer ctx);
void whisper_reset_timings(Pointer ctx);
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
// when `whisper_full_default_params()` tries to return a struct.
// WhisperFullParams whisper_full_default_params(int strategy);
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - WhisperSamplingStrategy.value
*/
Pointer whisper_full_default_params_by_ref(int strategy);
void whisper_free_params(Pointer params);
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
/**
* Number of generated text segments.
* A segment can be a few words, a sentence, or even a paragraph.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_n_segments (Pointer ctx);
/**
* @param state Pointer to WhisperState
*/
int whisper_full_n_segments_from_state(Pointer state);
/**
* Language id associated with the context's default state.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_lang_id(Pointer ctx);
/** Language id associated with the provided state */
int whisper_full_lang_id_from_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
* @return 0 on success
*/
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/** Get the start time of the specified segment. */
long whisper_full_get_segment_t0(Pointer ctx, int i_segment);
/** Get the start time of the specified segment from the state. */
long whisper_full_get_segment_t0_from_state(Pointer state, int i_segment);
/** Get the end time of the specified segment. */
long whisper_full_get_segment_t1(Pointer ctx, int i_segment);
/** Get the end time of the specified segment from the state. */
long whisper_full_get_segment_t1_from_state(Pointer state, int i_segment);
/** Get the text of the specified segment. */
String whisper_full_get_segment_text(Pointer ctx, int i_segment);
/** Get the text of the specified segment from the state. */
String whisper_full_get_segment_text_from_state(Pointer state, int i_segment);
/** Get the number of tokens in the specified segment. */
int whisper_full_n_tokens(Pointer ctx, int i_segment);
/** Get the number of tokens in the specified segment from the state. */
int whisper_full_n_tokens_from_state(Pointer state, int i_segment);
/** Get the token text of the specified token in the specified segment. */
String whisper_full_get_token_text(Pointer ctx, int i_segment, int i_token);
/** Get the token text of the specified token in the specified segment from the state. */
String whisper_full_get_token_text_from_state(Pointer ctx, Pointer state, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment. */
int whisper_full_get_token_id(Pointer ctx, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment from the state. */
int whisper_full_get_token_id_from_state(Pointer state, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment. */
WhisperTokenData whisper_full_get_token_data(Pointer ctx, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment from the state. */
WhisperTokenData whisper_full_get_token_data_from_state(Pointer state, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment. */
float whisper_full_get_token_p(Pointer ctx, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment from the state. */
float whisper_full_get_token_p_from_state(Pointer state, int i_segment, int i_token);
/**
* Benchmark function for memcpy.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_memcpy(int nThreads);
/**
* Benchmark function for memcpy as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_memcpy_str(int nThreads);
/**
* Benchmark function for ggml_mul_mat.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_ggml_mul_mat(int nThreads);
/**
* Benchmark function for ggml_mul_mat as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_ggml_mul_mat_str(int nThreads);
}

View File

@ -0,0 +1,24 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback before the encoder starts.
* If not null, called before the encoder starts.
* If it returns false, the computation is aborted.
*/
public interface WhisperEncoderBeginCallback extends Callback {
/**
* Callback method before the encoder starts.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param user_data User data.
* @return True if the computation should proceed, false otherwise.
*/
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
}

View File

@ -0,0 +1,25 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
/**
* Callback to filter logits.
* Can be used to modify the logits before sampling.
* If not null, called after applying temperature to logits.
*/
public interface WhisperLogitsFilterCallback extends Callback {
/**
* Callback method to filter logits.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param tokens The array of whisper_token_data.
* @param n_tokens The number of tokens.
* @param logits The array of logits.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
}

View File

@ -0,0 +1,24 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for the text segment.
* Called on every newly generated text segment.
* Use the whisper_full_...() functions to obtain the text segments.
*/
public interface WhisperNewSegmentCallback extends Callback {
/**
* Callback method for the text segment.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param n_new The number of newly generated text segments.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
}

View File

@ -0,0 +1,22 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for progress updates.
*/
public interface WhisperProgressCallback extends Callback {
/**
* Callback method for progress updates.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param progress The progress value.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
}

View File

@ -0,0 +1,4 @@
package io.github.ggerganov.whispercpp.ggml;
public class GgmlTensor {
}

View File

@ -0,0 +1,18 @@
package io.github.ggerganov.whispercpp.ggml;
public enum GgmlType {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
REMOVED_GGML_TYPE_Q4_2, // support has been removed
REMOVED_GGML_TYPE_Q4_3, // support has been removed
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q8_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
GGML_TYPE_COUNT,
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.model;
public enum EModel {
MODEL_UNKNOWN,
MODEL_TINY,
MODEL_BASE,
MODEL_SMALL,
MODEL_MEDIUM,
MODEL_LARGE,
}

View File

@ -0,0 +1,49 @@
package io.github.ggerganov.whispercpp;
import io.github.ggerganov.whispercpp.ggml.GgmlTensor;
import io.github.ggerganov.whispercpp.model.EModel;
public class WhisperModel {
// EModel type = EModel.MODEL_UNKNOWN;
//
// WhisperHParams hparams;
// WhisperFilters filters;
//
// // encoder.positional_embedding
// GgmlTensor e_pe;
//
// // encoder.conv1
// GgmlTensor e_conv_1_w;
// GgmlTensor e_conv_1_b;
//
// // encoder.conv2
// GgmlTensor e_conv_2_w;
// GgmlTensor e_conv_2_b;
//
// // encoder.ln_post
// GgmlTensor e_ln_w;
// GgmlTensor e_ln_b;
//
// // decoder.positional_embedding
// GgmlTensor d_pe;
//
// // decoder.token_embedding
// GgmlTensor d_te;
//
// // decoder.ln
// GgmlTensor d_ln_w;
// GgmlTensor d_ln_b;
//
// std::vector<whisper_layer_encoder> layers_encoder;
// std::vector<whisper_layer_decoder> layers_decoder;
//
// // context
// struct ggml_context * ctx;
//
// // the model memory buffer is read-only and can be shared between processors
// std::vector<uint8_t> * buf;
//
// // tensors
// int n_loaded;
// Map<String, GgmlTensor> tensors;
}

View File

@ -0,0 +1,62 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import com.sun.jna.Structure;
public class WhisperModelLoader extends Structure {
public Pointer context;
public ReadFunction read;
public EOFFunction eof;
public CloseFunction close;
public static class ReadFunction implements Callback {
public Pointer invoke(Pointer ctx, Pointer output, int readSize) {
// TODO
return ctx;
}
}
public static class EOFFunction implements Callback {
public boolean invoke(Pointer ctx) {
// TODO
return false;
}
}
public static class CloseFunction implements Callback {
public void invoke(Pointer ctx) {
// TODO
}
}
// public WhisperModelLoader(Pointer p) {
// super(p);
// read = new ReadFunction();
// eof = new EOFFunction();
// close = new CloseFunction();
// read.setCallback(this);
// eof.setCallback(this);
// close.setCallback(this);
// read.write();
// eof.write();
// close.write();
// }
public WhisperModelLoader() {
super();
}
public interface ReadCallback extends Callback {
Pointer invoke(Pointer ctx, Pointer output, int readSize);
}
public interface EOFCallback extends Callback {
boolean invoke(Pointer ctx);
}
public interface CloseCallback extends Callback {
void invoke(Pointer ctx);
}
}

View File

@ -0,0 +1,4 @@
package io.github.ggerganov.whispercpp.model;
public class WhisperState {
}

View File

@ -0,0 +1,50 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
/**
* Structure representing token data.
*/
public class WhisperTokenData extends Structure {
/** Token ID. */
public int id;
/** Forced timestamp token ID. */
public int tid;
/** Probability of the token. */
public float p;
/** Log probability of the token. */
public float plog;
/** Probability of the timestamp token. */
public float pt;
/** Sum of probabilities of all timestamp tokens. */
public float ptsum;
/**
* Start time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t0;
/**
* End time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t1;
/** Voice length of the token. */
public float vlen;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("id", "tid", "p", "plog", "pt", "ptsum", "t0", "t1", "vlen");
}
}

View File

@ -0,0 +1,19 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
public class BeamSearchParams extends Structure {
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
public int beam_size;
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
public float patience;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("beam_size", "patience");
}
}

View File

@ -0,0 +1,30 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.IntegerType;
import java.util.function.BooleanSupplier;
public class CBool extends IntegerType implements BooleanSupplier {
public static final int SIZE = 1;
public static final CBool FALSE = new CBool(0);
public static final CBool TRUE = new CBool(1);
public CBool() {
this(0);
}
public CBool(long value) {
super(SIZE, value, true);
}
@Override
public boolean getAsBoolean() {
return intValue() == 1;
}
@Override
public String toString() {
return intValue() == 1 ? "true" : "false";
}
}

View File

@ -0,0 +1,16 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Collections;
import java.util.List;
public class GreedyParams extends Structure {
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
public int best_of;
@Override
protected List<String> getFieldOrder() {
return Collections.singletonList("best_of");
}
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.params;
import java.util.List;
public class WhisperFilters {
int n_mel;
int n_fft;
List<Float> data;
}

View File

@ -0,0 +1,321 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import java.util.Arrays;
import java.util.List;
/**
* Parameters for the whisper_full() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_full_default_params()
*/
public class WhisperFullParams extends Structure {
public WhisperFullParams(Pointer p) {
super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
}
/** Sampling strategy for whisper_full() function. */
public int strategy;
/** Number of threads. (default = 4) */
public int n_threads;
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
public int n_max_text_ctx;
/** Start offset in milliseconds. (default = 0) */
public int offset_ms;
/** Audio duration to process in milliseconds. (default = 0) */
public int duration_ms;
/** Translate flag. (default = false) */
public CBool translate;
/** The compliment of translateMode() */
public void transcribeMode() {
translate = CBool.FALSE;
}
/** The compliment of transcribeMode() */
public void translateMode() {
translate = CBool.TRUE;
}
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public CBool no_context;
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public void enableContext(boolean enable) {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}
/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;
/** Flag to force single segment output (useful for streaming). (default = false) */
public void singleSegment(boolean single) {
single_segment = single ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public CBool print_special;
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print progress information. (default = true) */
public CBool print_progress;
/** Flag to print progress information. (default = true) */
public void printProgress(boolean enable) {
print_progress = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public CBool print_realtime;
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public void printRealtime(boolean enable) {
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public CBool print_timestamps;
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public void printTimestamps(boolean enable) {
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public CBool token_timestamps;
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public void tokenTimestamps(boolean enable) {
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
public float thold_pt;
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
public float thold_ptsum;
/** Maximum segment length in characters. (default = 0) */
public int max_len;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public CBool split_on_word;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public void splitOnWord(boolean enable) {
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
}
/** Maximum tokens per segment (0, default = no limit) */
public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */
public int audio_ctx;
/** Enable tinydiarize (default = false) */
public CBool tdrz_enable;
/** Enable tinydiarize (default = false) */
public void tdrzEnable(boolean enable) {
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
}
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
public void setPromptTokens(int[] tokens) {
Memory mem = new Memory(tokens.length * 4L);
mem.write(0, tokens, 0, tokens.length);
prompt_tokens = mem;
}
/** Number of prompt tokens. */
public int prompt_n_tokens;
/** Language for auto-detection.
* For auto-detection, set to `null`, `""`, or "auto". */
public String language;
/** Flag to indicate whether to detect language automatically. */
public CBool detect_language;
/** Flag to indicate whether to detect language automatically. */
public void detectLanguage(boolean enable) {
detect_language = enable ? CBool.TRUE : CBool.FALSE;
}
// Common decoding parameters.
/** Flag to suppress blank tokens. */
public CBool suppress_blank;
public void suppressBlanks(boolean enable) {
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to suppress non-speech tokens. */
public CBool suppress_non_speech_tokens;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
public float temperature;
/** Maximum initial timestamp. */
public float max_initial_ts;
/** Length penalty. */
public float length_penalty;
// Fallback parameters.
/** Temperature increment. */
public float temperature_inc;
/** Entropy threshold (similar to OpenAI's "compression_ratio_threshold"). */
public float entropy_thold;
/** Log probability threshold. */
public float logprob_thold;
/** No speech threshold. */
public float no_speech_thold;
/** Greedy decoding parameters. */
public GreedyParams greedy;
/**
* Beam search decoding parameters.
*/
public BeamSearchParams beam_search;
public void setBestOf(int bestOf) {
if (greedy == null) {
greedy = new GreedyParams();
}
greedy.best_of = bestOf;
}
public void setBeamSize(int beamSize) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
}
public void setBeamSizeAndPatience(int beamSize, float patience) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
beam_search.patience = patience;
}
/**
* Callback for every newly generated text segment.
* WhisperNewSegmentCallback
*/
public Pointer new_segment_callback;
/**
* User data for the new_segment_callback.
*/
public Pointer new_segment_callback_user_data;
/**
* Callback on each progress update.
* WhisperProgressCallback
*/
public Pointer progress_callback;
/**
* User data for the progress_callback.
*/
public Pointer progress_callback_user_data;
/**
* Callback each time before the encoder starts.
* WhisperEncoderBeginCallback
*/
public Pointer encoder_begin_callback;
/**
* User data for the encoder_begin_callback.
*/
public Pointer encoder_begin_callback_user_data;
/**
* Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback
*/
public Pointer logits_filter_callback;
/**
* User data for the logits_filter_callback.
*/
public Pointer logits_filter_callback_user_data;
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
new_segment_callback = CallbackReference.getFunctionPointer(callback);
}
public void setProgressCallback(WhisperProgressCallback callback) {
progress_callback = CallbackReference.getFunctionPointer(callback);
}
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
}
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data");
}
}

View File

@ -0,0 +1,15 @@
package io.github.ggerganov.whispercpp.params;
public class WhisperHParams {
int n_vocab = 51864;
int n_audio_ctx = 1500;
int n_audio_state = 384;
int n_audio_head = 6;
int n_audio_layer = 4;
int n_text_ctx = 448;
int n_text_state = 384;
int n_text_head = 6;
int n_text_layer = 4;
int n_mels = 80;
int ftype = 1;
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.params;
/** Available sampling strategies */
public enum WhisperSamplingStrategy {
/** similar to OpenAI's GreedyDecoder */
WHISPER_SAMPLING_GREEDY,
/** similar to OpenAI's BeamSearchDecoder */
WHISPER_SAMPLING_BEAM_SEARCH
}

View File

@ -0,0 +1,102 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ggerganov.whispercpp.params.CBool;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.FileNotFoundException;
class WhisperCppTest {
private static WhisperCpp whisper = new WhisperCpp();
private static boolean modelInitialised = false;
@BeforeAll
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
}
}
@Test
void testGetDefaultFullParams_BeamSearch() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(2, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@Test
void testGetDefaultFullParams_Greedy() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(2, params.greedy.best_of);
}
@Test
void testFullTranscribe() throws Exception {
if (!modelInitialised) {
System.out.println("Model not initialised, skipping test");
return;
}
// Given
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";
try {
audioInputStream.read(b);
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
floats[j] = intSample / 32767.0f;
}
// When
String result = whisper.fullTranscribe(params, floats);
// Then
System.err.println(result);
assertEquals("And so my fellow Americans ask not what your country can do for you " +
"ask what you can do for your country.",
result.replace(",", ""));
} finally {
audioInputStream.close();
}
}
}

View File

@ -0,0 +1,17 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
class WhisperJnaLibraryTest {
@Test
void testWhisperPrint_system_info() {
String systemInfo = WhisperCppJnaLibrary.instance.whisper_print_system_info();
// eg: "AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0
// | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | "
System.out.println("System info: " + systemInfo);
assertTrue(systemInfo.length() > 10);
}
}

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.2.1",
"version": "1.4.2",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,146 @@
//
// whisper-decoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
/// token_data as 1 by 1 matrix of 32-bit integers
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
/// var_1195 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1195;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1195:(MLMultiArray *)var_1195 NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param token_data as 1 by 1 matrix of 32-bit integers:
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_decoder_implOutput *>
*/
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,201 @@
//
// whisper-decoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-decoder-impl.h"
@implementation whisper_decoder_implInput
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data {
self = [super init];
if (self) {
_token_data = token_data;
_audio_data = audio_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"token_data", @"audio_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"token_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.token_data];
}
if ([featureName isEqualToString:@"audio_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.audio_data];
}
return nil;
}
@end
@implementation whisper_decoder_implOutput
- (instancetype)initWithVar_1195:(MLMultiArray *)var_1195 {
self = [super init];
if (self) {
_var_1195 = var_1195;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1195"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1195"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1195];
}
return nil;
}
@end
@implementation whisper_decoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1195:(MLMultiArray *)[outFeatures featureValueForName:@"var_1195"].multiArrayValue];
}
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1195:(MLMultiArray *)[resultProvider featureValueForName:@"var_1195"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end

View File

@ -0,0 +1,142 @@
//
// whisper-encoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
/// output as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * output;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_encoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_encoder_implOutput *>
*/
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,197 @@
//
// whisper-encoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder-impl.h"
@implementation whisper_encoder_implInput
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data {
self = [super init];
if (self) {
_logmel_data = logmel_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"logmel_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"logmel_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.logmel_data];
}
return nil;
}
@end
@implementation whisper_encoder_implOutput
- (instancetype)initWithOutput:(MLMultiArray *)output {
self = [super init];
if (self) {
_output = output;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"output"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"output"]) {
return [MLFeatureValue featureValueWithMultiArray:self.output];
}
return nil;
}
@end
@implementation whisper_encoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_encoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end

22
coreml/whisper-encoder.h Normal file
View File

@ -0,0 +1,22 @@
// Wrapper of the Core ML Whisper Encoder model
//
// Code is derived from the work of Github user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context;
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
void whisper_coreml_free(struct whisper_coreml_context * ctx);
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out);
#if __cplusplus
}
#endif

63
coreml/whisper-encoder.mm Normal file
View File

@ -0,0 +1,63 @@
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder.h"
#import "whisper-encoder-impl.h"
#import <CoreML/CoreML.h>
#include <stdlib.h>
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context {
const void * data;
};
struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
if (data == NULL) {
return NULL;
}
whisper_coreml_context * ctx = new whisper_coreml_context;
ctx->data = data;
return ctx;
}
void whisper_coreml_free(struct whisper_coreml_context * ctx) {
CFRelease(ctx->data);
delete ctx;
}
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
shape: @[@1, @80, @3000]
dataType: MLMultiArrayDataTypeFloat32
strides: @[@(240000), @(3000), @1]
deallocator: nil
error: nil
];
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
memcpy(out, outCoreML.output.dataPointer, outCoreML.output.count * sizeof(float));
}
#if __cplusplus
}
#endif

View File

@ -4,7 +4,7 @@ find_package(Threads REQUIRED)
# third-party
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# SDL2
find_package(SDL2 REQUIRED)
@ -21,13 +21,17 @@ set(TARGET common)
add_library(${TARGET} STATIC
common.h
common.cpp
common-ggml.h
common-ggml.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# common-sdl
set(TARGET common-sdl)
@ -62,5 +66,7 @@ else()
add_subdirectory(stream)
add_subdirectory(command)
add_subdirectory(bench)
add_subdirectory(quantize)
add_subdirectory(talk)
add_subdirectory(talk-llama)
endif()

View File

@ -1,15 +1,23 @@
const path = require('path');
const { whisper } = require(path.join(__dirname, '../../../build/Release/whisper-addon'));
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../../build/Release/whisper-addon"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const whisperParamsMock = {
language: 'en',
model: path.join(__dirname, '../../../models/ggml-base.en.bin'),
fname_inp: path.join(__dirname, '../../../samples/jfk.wav'),
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
};
describe("Run whisper.node", () => {
test("it should receive a non-empty value", async () => {
let result = await whisperAsync(whisperParamsMock);
test("it should receive a non-empty value", () => {
expect(whisper(whisperParamsMock).length).toBeGreaterThan(0);
});
expect(result.length).toBeGreaterThan(0);
}, 10000);
});

View File

@ -160,22 +160,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@ -243,8 +227,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.initial_prompt = params.prompt.c_str();
whisper_print_user_data user_data = { &params, &pcmf32s };

View File

@ -31,9 +31,9 @@ endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s PTHREAD_POOL_SIZE_STRICT=0 \
-s INITIAL_MEMORY=2000MB \
-s TOTAL_MEMORY=2000MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -35,6 +35,15 @@
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the model you would like to use and click the "Bench" button.<br>
@ -44,11 +53,18 @@
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<button id="fetch-whisper-small-en-q5_1" onclick="loadWhisper('small-en-q5_1')">small.en (Q5_1, 182 MB)</button>
<button id="fetch-whisper-medium-en-q5_0" onclick="loadWhisper('medium-en-q5_0')">medium.en (Q5_0, 515 MB)</button>
<button id="fetch-whisper-large-q5_0" onclick="loadWhisper('large-q5_0')">large (Q5_0, 1030 MB)</button>
<span id="fetch-whisper-progress"></span>
</div>
<br>
@ -160,6 +176,14 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
}
@ -168,19 +192,42 @@
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'small.en': 466,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
'small-en-q5_1': 182,
'medium-en-q5_0': 515,
'large-q5_0': 1030,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
@ -190,9 +237,18 @@
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-en-q5_0'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-large-q5_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);

View File

@ -28,31 +28,6 @@ std::string g_transcribed = "";
std::vector<float> g_pcmf32;
// compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
void command_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;

View File

@ -35,6 +35,15 @@
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the model you would like to use, click the "Start" button and follow the instructions.
@ -45,6 +54,10 @@
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
@ -162,11 +175,17 @@
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
};
let url = urls[model];
@ -177,6 +196,10 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
@ -188,6 +211,10 @@
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};

View File

@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)

View File

@ -163,31 +163,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
return result;
}
// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
std::vector<std::string> read_allowed_commands(const std::string & fname) {
std::vector<std::string> allowed_commands;

246
examples/common-ggml.cpp Normal file
View File

@ -0,0 +1,246 @@
#include "common-ggml.h"
#include <regex>
#include <map>
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
};
void ggml_print_ftypes(FILE * fp) {
for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {
fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
}
}
enum ggml_ftype ggml_parse_ftype(const char * str) {
enum ggml_ftype ftype;
if (str[0] == 'q') {
const auto it = GGML_FTYPE_MAP.find(str);
if (it == GGML_FTYPE_MAP.end()) {
fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str);
return GGML_FTYPE_UNKNOWN;
}
ftype = it->second;
} else {
ftype = (enum ggml_ftype) atoi(str);
}
return ftype;
}
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip) {
ggml_type qtype = GGML_TYPE_F32;
switch (ftype) {
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_UNKNOWN:
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_Q2_K:
case GGML_FTYPE_MOSTLY_Q3_K:
case GGML_FTYPE_MOSTLY_Q4_K:
case GGML_FTYPE_MOSTLY_Q5_K:
case GGML_FTYPE_MOSTLY_Q6_K:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
}
};
if (!ggml_is_quantized(qtype)) {
fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, qtype, ggml_type_name(qtype));
return false;
}
size_t total_size_org = 0;
size_t total_size_new = 0;
std::vector<float> work;
std::vector<uint8_t> data_u8;
std::vector<ggml_fp16_t> data_f16;
std::vector<float> data_f32;
std::vector<int64_t> hist_all(1 << 4, 0);
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
finp.read(reinterpret_cast<char *>(&length), sizeof(length));
finp.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (finp.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[4] = { 1, 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
finp.read (&name[0], length);
printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype));
bool quantize = false;
// check if we should quantize this tensor
for (const auto & s : to_quant) {
if (std::regex_match(name, std::regex(s))) {
quantize = true;
break;
}
}
// check if we should skip this tensor
for (const auto & s : to_skip) {
if (std::regex_match(name, std::regex(s))) {
quantize = false;
break;
}
}
// quantize only 2D tensors
quantize &= (n_dims == 2);
if (quantize) {
if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {
fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
if (ttype == GGML_TYPE_F16) {
data_f16.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
data_f32.resize(nelements);
for (int i = 0; i < nelements; ++i) {
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
}
} else {
data_f32.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
}
ttype = qtype;
} else {
const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t);
data_u8.resize(nelements*bpe);
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
}
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
fout.write(reinterpret_cast<char *>(&ttype), sizeof(ttype));
for (int i = 0; i < n_dims; ++i) {
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
}
fout.write(&name[0], length);
if (quantize) {
work.resize(nelements); // for quantization
size_t cur_size = 0;
std::vector<int64_t> hist_cur(1 << 4, 0);
switch ((ggml_type) ttype) {
case GGML_TYPE_Q4_0:
{
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1:
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_1:
{
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q8_0:
{
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
}
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
total_size_new += cur_size;
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
for (int i = 0; i < (int) hist_cur.size(); ++i) {
hist_all[i] += hist_cur[i];
}
for (int i = 0; i < (int) hist_cur.size(); ++i) {
printf("%5.3f ", hist_cur[i] / (float)nelements);
}
printf("\n");
} else {
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
total_size_new += data_u8.size();
}
total_size_org += nelements * sizeof(float);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype));
{
int64_t sum_all = 0;
for (int i = 0; i < (int) hist_all.size(); ++i) {
sum_all += hist_all[i];
}
printf("%s: hist: ", __func__);
for (int i = 0; i < (int) hist_all.size(); ++i) {
printf("%5.3f ", hist_all[i] / (float)sum_all);
}
printf("\n");
}
return true;
}

18
examples/common-ggml.h Normal file
View File

@ -0,0 +1,18 @@
#pragma once
#include "ggml.h"
#include <fstream>
#include <vector>
#include <string>
enum ggml_ftype ggml_parse_ftype(const char * str);
void ggml_print_ftypes(FILE * fp = stderr);
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip);

View File

@ -6,12 +6,126 @@
#include "dr_wav.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <regex>
#include <locale>
#include <codecvt>
#include <sstream>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i];
} else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
params.top_k = std::max(1, std::stoi(argv[++i]));
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
} else if (arg == "--repeat-last-n") {
params.repeat_last_n = std::stof(argv[++i]);
} else if (arg == "--repeat-penalty") {
params.repeat_penalty = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "-ip" || arg == "--interactive-port") {
params.interactive = true;
params.interactive_port = std::stoi(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else if (arg == "-f" || arg == "--file") {
if (++i > argc) {
fprintf(stderr, "Invalid file param");
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
} else if (arg == "-tt" || arg == "--token_test") {
params.token_test = argv[++i];
}
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " load prompt from a file\n");
fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
fprintf(stderr, " test tokenization\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
case 0: return "So";
case 1: return "Once upon a time";
case 2: return "When";
case 3: return "The";
case 4: return "After";
case 5: return "If";
case 6: return "import";
case 7: return "He";
case 8: return "She";
case 9: return "They";
default: return "To";
}
return "The";
}
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
@ -27,6 +141,463 @@ std::string replace(const std::string & s, const std::string & from, const std::
return result;
}
void gpt_vocab::add_special_token(const std::string & token) {
special_tokens.push_back(token);
}
std::map<std::string, int32_t> json_parse(const std::string & fname) {
std::map<std::string, int32_t> result;
// read file into string
std::string json;
{
std::ifstream ifs(fname);
if (!ifs) {
fprintf(stderr, "Failed to open %s\n", fname.c_str());
exit(1);
}
json = std::string((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
}
if (json[0] != '{') {
return result;
}
// parse json
{
bool has_key = false;
bool in_token = false;
std::string str_key = "";
std::string str_val = "";
int n = json.size();
for (int i = 1; i < n; ++i) {
if (!in_token) {
if (json[i] == ' ') continue;
if (json[i] == '"') {
in_token = true;
continue;
}
} else {
if (json[i] == '\\' && i+1 < n) {
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
++i;
} else if (json[i] == '"') {
if (has_key == false) {
has_key = true;
++i;
while (json[i] == ' ') ++i;
++i; // :
while (json[i] == ' ') ++i;
if (json[i] != '\"') {
while (json[i] != ',' && json[i] != '}') {
str_val += json[i++];
}
has_key = false;
} else {
in_token = true;
continue;
}
} else {
has_key = false;
}
str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "
try {
result[str_key] = std::stoi(str_val);
} catch (...) {
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
}
str_key = "";
str_val = "";
in_token = false;
continue;
}
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
}
}
}
return result;
}
std::string convert_to_utf8(const std::wstring & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(input);
}
std::wstring convert_to_wstring(const std::string & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.from_bytes(input);
}
void gpt_split_words(std::string str, std::vector<std::string>& words) {
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
std::string special_tokens_subpattern;
for (const auto & token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
}
std::regex re(special_tokens_subpattern);
std::smatch m;
// Split the text by special tokens.
while (std::regex_search(str, m, re)) {
// Split the substrings in-between special tokens into words.
gpt_split_words(m.prefix(), words);
// Add matched special tokens as words.
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
// Remaining text without special tokens will be handled below.
}
gpt_split_words(str, words);
}
// find the longest token that forms each word in words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
for (int i = 0; i < (int) word.size(); ){
for (int j = word.size() - 1; j >= i; j--){
auto cand = word.substr(i, j-i+1);
auto it = vocab.token_to_id.find(cand);
if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
tokens.push_back(it->second);
i = j + 1;
break;
}
else if (j == i){ // word.substr(i, 1) has no matching
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
i++;
}
}
}
}
return tokens;
}
std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
std::vector<gpt_vocab::id> output;
std::stringstream ss(input);
std::string token;
while (std::getline(ss, token, delimiter)) {
output.push_back(std::stoi(token));
}
return output;
}
std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
if (fpath_test.empty()){
fprintf(stderr, "%s : No test file found.\n", __func__);
return std::map<std::string, std::vector<gpt_vocab::id>>();
}
std::map<std::string, std::vector<gpt_vocab::id>> tests;
auto fin = std::ifstream(fpath_test, std::ios_base::in);
const char * delimeter = " => ";
const char del_tok = ',';
std::string line;
while (std::getline(fin, line)) {
size_t delimiterPos = line.find(delimeter);
if (delimiterPos != std::string::npos) {
std::string text = line.substr(0, delimiterPos);
std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));
tests[text] = parse_tokens_from_string(s_tokens, del_tok);
}
}
return tests;
}
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);
size_t n_fails = 0;
for (const auto & test : tests) {
std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);
if (tokens != test.second){
n_fails++;
// print out failure cases
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
fprintf(stderr, "%s : tokens in hf: ", __func__);
for (const auto & t : test.second) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : tokens in ggml: ", __func__);
for (const auto & t : tokens) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
vocab.token_to_id = ::json_parse(fname);
for (const auto & kv : vocab.token_to_id) {
vocab.id_to_token[kv.second] = kv.first;
}
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
// print the vocabulary
//for (auto kv : vocab.token_to_id) {
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
//}
return true;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const double scale = 1.0/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
//printf("\n");
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
const auto * plogits = logits;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
gpt_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
// printf("\n");
// for (int i = 0; i < (int) probs.size(); i++) {
// for (int i = 0; i < 10; i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
// }
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
@ -160,3 +731,27 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true;
}
float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}

View File

@ -1,10 +1,49 @@
// Various helper functions and utilities
#pragma once
// needs to match WHISPER_SAMPLE_RATE
#include <string>
#include <map>
#include <vector>
#include <random>
#include <thread>
#define COMMON_SAMPLE_RATE 16000
#include <vector>
#include <string>
//
// CLI argument parsing
//
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
int32_t n_batch = 8; // batch size for prompt processing
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t repeat_last_n = 64;
float repeat_penalty = 1.00f;
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt = "";
std::string token_test = "";
bool interactive = false;
int32_t interactive_port = -1;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
//
// Vocab utils
//
std::string trim(const std::string & s);
@ -13,6 +52,82 @@ std::string replace(
const std::string & from,
const std::string & to);
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string & token);
};
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
std::string convert_to_utf8(const std::wstring & input);
std::wstring convert_to_wstring(const std::string & input);
void gpt_split_words(std::string str, std::vector<std::string>& words);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
// test outputs of gpt_tokenize
//
// - compare with tokens generated by the huggingface tokenizer
// - test cases are chosen based on the model's main language (under 'prompt' directory)
// - if all sentences are tokenized identically, print 'All tests passed.'
// - otherwise, print sentence, huggingface tokens, ggml tokens
//
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test);
// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// sample next token given probabilities for each embedding
//
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
// TODO: temperature is not implemented
//
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng);
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng);
//
// Audio utils
//
// Read WAV audio file and store the PCM data into pcmf32
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
@ -38,3 +153,5 @@ bool vad_simple(
float freq_thold,
bool verbose);
// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1);

View File

@ -145,7 +145,15 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
var db = event.target.result;
var tx = db.transaction(['models'], 'readwrite');
var os = tx.objectStore('models');
var rq = os.put(data, url);
var rq = null;
try {
var rq = os.put(data, url);
} catch (e) {
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
cbCancel();
return;
}
rq.onsuccess = function (event) {
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
@ -180,7 +188,6 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
rq.onabort = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: abort');
cbCancel();
};
}

View File

@ -8,6 +8,11 @@
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
@ -56,33 +61,41 @@ struct whisper_params {
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t best_of = 2;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool speed_up = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string model = "models/ggml-base.en.bin";
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
@ -108,39 +121,43 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -173,10 +190,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
@ -185,11 +204,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, "\n");
}
@ -199,6 +220,39 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s;
};
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "0";
} else if (energy1 > 1.1*energy0) {
speaker = "1";
} else {
speaker = "?";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only) {
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -207,8 +261,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
std::string speaker = "";
int64_t t0;
int64_t t1;
int64_t t0 = 0;
int64_t t1 = 0;
// print the last n_new segments
const int s0 = n_segments - n_new;
@ -228,28 +282,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
if (params.print_colors) {
@ -274,6 +307,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
@ -283,7 +322,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) {
bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -295,13 +334,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
fout << text << "\n";
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << speaker << text << "\n";
}
return true;
}
bool output_vtt(struct whisper_context * ctx, const char * fname) {
bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -317,15 +365,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << text << "\n\n";
fout << speaker << text << "\n\n";
}
return true;
}
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -339,16 +395,53 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << i + 1 + params.offset_n << "\n";
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << text << "\n\n";
fout << speaker << text << "\n\n";
}
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
char *escape_double_quotes_and_backslashes(const char *str) {
if (str == NULL) {
return NULL;
}
size_t escaped_length = strlen(str) + 1;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"' || str[i] == '\\') {
escaped_length++;
}
}
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL) {
return NULL;
}
size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"' || str[i] == '\\') {
escaped[pos++] = '\\';
}
escaped[pos++] = str[i];
}
// no need to set zero due to calloc() being used prior
return escaped;
}
bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -358,20 +451,32 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,text\n";
fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2)
{
fout << "speaker,";
}
fout << "text\n";
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
char * text_escaped = escape_double_quotes_and_backslashes(text);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text << "\"\n";
fout << 10 * t0 << "," << 10 * t1 << ",";
if (params.diarize && pcmf32s.size() == 2)
{
fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
}
fout << "\"" << text_escaped << "\"\n";
}
return true;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
int indent = 0;
@ -385,13 +490,13 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_arr = [&](bool end = false) {
auto end_arr = [&](bool end) {
indent--;
doindent();
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name = nullptr) {
auto start_obj = [&](const char *name) {
doindent();
if (name) {
fout << "\"" << name << "\": {\n";
@ -401,7 +506,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_obj = [&](bool end = false) {
auto end_obj = [&](bool end) {
indent--;
doindent();
fout << (end ? "}\n" : "},\n");
@ -412,22 +517,24 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
fout << "\"" << name << "\": ";
};
auto value_s = [&](const char *name, const char *val, bool end = false) {
auto value_s = [&](const char *name, const char *val, bool end) {
start_value(name);
fout << "\"" << val << (end ? "\"\n" : "\",\n");
char * val_escaped = escape_double_quotes_and_backslashes(val);
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
free(val_escaped);
};
auto end_value = [&](bool end = false) {
auto end_value = [&](bool end) {
fout << (end ? "\n" : ",\n");
};
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
auto value_i = [&](const char *name, const int64_t val, bool end) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end = false) {
auto value_b = [&](const char *name, const bool val, bool end) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
@ -439,53 +546,62 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj();
value_s("systeminfo", whisper_print_system_info());
start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false);
start_obj("model");
value_s("type", whisper_model_type_readable(ctx));
value_b("multilingual", whisper_is_multilingual(ctx));
value_i("vocab", whisper_model_n_vocab(ctx));
value_s("type", whisper_model_type_readable(ctx), false);
value_b("multilingual", whisper_is_multilingual(ctx), false);
value_i("vocab", whisper_model_n_vocab(ctx), false);
start_obj("audio");
value_i("ctx", whisper_model_n_audio_ctx(ctx));
value_i("state", whisper_model_n_audio_state(ctx));
value_i("head", whisper_model_n_audio_head(ctx));
value_i("ctx", whisper_model_n_audio_ctx(ctx), false);
value_i("state", whisper_model_n_audio_state(ctx), false);
value_i("head", whisper_model_n_audio_head(ctx), false);
value_i("layer", whisper_model_n_audio_layer(ctx), true);
end_obj();
end_obj(false);
start_obj("text");
value_i("ctx", whisper_model_n_text_ctx(ctx));
value_i("state", whisper_model_n_text_state(ctx));
value_i("head", whisper_model_n_text_head(ctx));
value_i("leyer", whisper_model_n_text_layer(ctx), true);
end_obj();
value_i("mels", whisper_model_n_mels(ctx));
value_i("f16", whisper_model_f16(ctx), true);
end_obj();
value_i("ctx", whisper_model_n_text_ctx(ctx), false);
value_i("state", whisper_model_n_text_state(ctx), false);
value_i("head", whisper_model_n_text_head(ctx), false);
value_i("layer", whisper_model_n_text_layer(ctx), true);
end_obj(false);
value_i("mels", whisper_model_n_mels(ctx), false);
value_i("ftype", whisper_model_ftype(ctx), true);
end_obj(false);
start_obj("params");
value_s("model", params.model.c_str());
value_s("language", params.language.c_str());
value_s("model", params.model.c_str(), false);
value_s("language", params.language.c_str(), false);
value_b("translate", params.translate, true);
end_obj();
end_obj(false);
start_obj("result");
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
end_obj();
end_obj(false);
start_arr("transcription");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj();
start_obj("timestanps");
value_s("from", to_timestamp(t0, true).c_str());
start_obj(nullptr);
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str(), false);
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj();
end_obj(false);
start_obj("offsets");
value_i("from", t0 * 10);
value_i("from", t0 * 10, false);
value_i("to", t1 * 10, true);
end_obj();
value_s("text", text, true);
end_obj(false);
value_s("text", text, !params.diarize && !params.tinydiarize);
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
}
if (params.tinydiarize) {
value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
}
end_obj(i == (n_segments - 1));
}
@ -497,7 +613,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -534,6 +650,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
@ -542,13 +663,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
continue;
}
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
std::string txt_bg = "";
std::string txt_fg = ""; // highlight token
std::string txt_ul = ""; // underline
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
if (params.diarize && pcmf32s.size() == 2) {
txt_bg = speaker;
txt_fg = speaker;
txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
}
txt_bg.append("> ");
txt_fg.append("> ");
txt_ul.append("\\ \\ ");
{
for (int k = 0; k < n; ++k) {
@ -611,10 +738,51 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
return true;
}
bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t = whisper_full_get_segment_t0(ctx, i);
int64_t msec = t * 10;
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[16];
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
std::string timestamp_lrc = std::string(buf);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
}
return true;
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
return 1;
}
@ -630,6 +798,12 @@ int main(int argc, char ** argv) {
exit(0);
}
if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@ -639,21 +813,8 @@ int main(int argc, char ** argv) {
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
@ -684,11 +845,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
@ -706,6 +871,7 @@ int main(int argc, char ** argv) {
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
@ -718,8 +884,9 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
@ -762,37 +929,43 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str());
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str());
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params);
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}
}
}

View File

@ -0,0 +1,6 @@
set(TARGET quantize)
add_executable(${TARGET} quantize.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -0,0 +1,3 @@
# quantize
Tool for integer quantization of Whisper `ggml` model files

View File

@ -0,0 +1,223 @@
#include "ggml.h"
#include "common.h"
#include "common-ggml.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <regex>
// default hparams (Whisper tiny)
struct whisper_hparams {
int32_t n_vocab = 51864;
int32_t n_audio_ctx = 1500;
int32_t n_audio_state = 384;
int32_t n_audio_head = 6;
int32_t n_audio_layer = 4;
int32_t n_text_ctx = 448;
int32_t n_text_state = 384;
int32_t n_text_head = 6;
int32_t n_text_layer = 4;
int32_t n_mels = 80;
int32_t ftype = 1;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
// quantize a model
bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
gpt_vocab vocab;
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
auto finp = std::ifstream(fname_inp, std::ios::binary);
if (!finp) {
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
return false;
}
auto fout = std::ofstream(fname_out, std::ios::binary);
if (!fout) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
// verify magic
{
uint32_t magic;
finp.read((char *) &magic, sizeof(magic));
if (magic != GGML_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
return false;
}
fout.write((char *) &magic, sizeof(magic));
}
whisper_hparams hparams;
// load hparams
{
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
finp.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
finp.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fout.write((const char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((const char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((const char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((const char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((const char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((const char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((const char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((const char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((const char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((const char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((const char *) &ftype_dst, sizeof(hparams.ftype));
}
// load mel filters
{
whisper_filters filters;
finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel));
fout.write((char *) &filters.n_mel, sizeof(filters.n_mel));
finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft));
fout.write((char *) &filters.n_fft, sizeof(filters.n_fft));
filters.data.resize(filters.n_mel * filters.n_fft);
finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float));
fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float));
}
// load vocab
{
int32_t n_vocab = 0;
finp.read ((char *) &n_vocab, sizeof(n_vocab));
fout.write((char *) &n_vocab, sizeof(n_vocab));
//if (n_vocab != hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
// __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
// return false;
//}
char word[128];
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
finp.read ((char *) &len, sizeof(len));
fout.write((char *) &len, sizeof(len));
word[len] = '\0';
finp.read ((char *) word, len);
fout.write((char *) word, len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// regexes of tensor names to not be quantized
const std::vector<std::string> to_skip = {
//"encoder.*",
"encoder.conv1.bias",
"encoder.conv2.bias",
"encoder.positional_embedding",
"decoder.positional_embedding",
};
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
return false;
}
finp.close();
fout.close();
return true;
}
int main(int argc, char ** argv) {
if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
ggml_print_ftypes(stderr);
return 1;
}
// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
const std::string fname_inp = argv[1];
const std::string fname_out = argv[2];
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
const int64_t t_main_start_us = ggml_time_us();
int64_t t_quantize_us = 0;
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}
t_quantize_us = ggml_time_us() - t_start_us;
}
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
}
return 0;
}

View File

@ -35,6 +35,15 @@
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the model you would like to use, click the "Start" button and start speaking
@ -45,6 +54,10 @@
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
@ -162,11 +175,17 @@
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
};
let url = urls[model];
@ -177,6 +196,10 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
@ -188,6 +211,10 @@
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};

View File

@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)

View File

@ -43,6 +43,7 @@ struct whisper_params {
bool speed_up = false;
bool translate = false;
bool no_fallback = false;
bool print_special = false;
bool no_context = true;
bool no_timestamps = false;
@ -73,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@ -94,22 +96,23 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
@ -148,7 +151,7 @@ int main(int argc, char ** argv) {
// whisper init
if (whisper_lang_id(params.language.c_str()) == -1) {
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -297,7 +300,8 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
//wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
@ -379,6 +383,7 @@ int main(int argc, char ** argv) {
}
}
}
fflush(stdout);
}
}

1
examples/talk-llama/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
audio.mp3

View File

@ -0,0 +1,16 @@
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
endif ()

View File

@ -0,0 +1,50 @@
# talk-llama
Talk with an LLaMA AI in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
## Building
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk-llama" executable
make talk-llama
# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
## Session
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
Example usage:
```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or Windows SpeechSynthesizer, but you can use whatever you wish.
## Discussion
If you have any feedback, please let "us" know in the following discussion: https://github.com/ggerganov/whisper.cpp/discussions/672?converting=1

View File

@ -0,0 +1,20 @@
import sys
import importlib.util
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import generate, play, save
# Get a Voice object, by name or UUID
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
# Save the TTS to a file
save(audio, "audio.mp3")

View File

@ -0,0 +1,474 @@
// Internal header to be included only by llama.cpp.
// Contains wrappers around OS interfaces.
#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H
#include <cstdio>
#include <cstdint>
#include <cerrno>
#include <cstring>
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <string>
#include <vector>
#include <stdexcept>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#endif
#define LLAMA_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#ifdef __GNUC__
#ifdef __MINGW32__
__attribute__((format(gnu_printf, 1, 2)))
#else
__attribute__((format(printf, 1, 2)))
#endif
#endif
static std::string format(const char * fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
LLAMA_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, len, 1, fp);
if (ferror(fp)) {
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret != 1) {
throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, len, 1, fp);
if (ret != 1) {
throw std::runtime_error(format("write error: %s", strerror(errno)));
}
}
void write_u32(std::uint32_t val) {
write_raw(&val, sizeof(val));
}
~llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
if (!size) {
return "FormatMessageA failed";
}
std::string ret(buf, size);
LocalFree(buf);
return ret;
}
#endif
struct llama_mmap {
void * addr;
size_t size;
llama_mmap(const llama_mmap &) = delete;
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
#ifdef __linux__
flags |= MAP_POPULATE;
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
}
if (prefetch > 0) {
// Advise the kernel to preload the mapped memory
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
}
~llama_mmap() {
munmap(addr, size);
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
DWORD error = GetLastError();
if (hMapping == NULL) {
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
error = GetLastError();
CloseHandle(hMapping);
if (addr == NULL) {
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
if (prefetch) {
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
}
~llama_mmap() {
if (!UnmapViewOfFile(addr)) {
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *, bool prefetch = true) {
(void)prefetch;
throw std::runtime_error(std::string("mmap not supported"));
}
#endif
};
// Represents some region of memory being locked using mlock or VirtualLock;
// will automatically unlock on destruction.
struct llama_mlock {
void * addr = NULL;
size_t size = 0;
bool failed_already = false;
llama_mlock() {}
llama_mlock(const llama_mlock &) = delete;
~llama_mlock() {
if (size) {
raw_unlock(addr, size);
}
}
void init(void * ptr) {
LLAMA_ASSERT(addr == NULL && size == 0);
addr = ptr;
}
void grow_to(size_t target_size) {
LLAMA_ASSERT(addr);
if (failed_already) {
return;
}
size_t granularity = lock_granularity();
target_size = (target_size + granularity - 1) & ~(granularity - 1);
if (target_size > size) {
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
size = target_size;
} else {
failed_already = true;
}
}
}
#ifdef _POSIX_MEMLOCK_RANGE
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
return (size_t) sysconf(_SC_PAGESIZE);
}
#ifdef __APPLE__
#define MLOCK_SUGGESTION \
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
#else
#define MLOCK_SUGGESTION \
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
#endif
bool raw_lock(const void * addr, size_t size) {
if (!mlock(addr, size)) {
return true;
} else {
char* errmsg = std::strerror(errno);
bool suggest = (errno == ENOMEM);
// Check if the resource limit is fine after all
struct rlimit lock_limit;
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
suggest = false;
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
suggest = false;
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
return false;
}
}
#undef MLOCK_SUGGESTION
void raw_unlock(void * addr, size_t size) {
if (munlock(addr, size)) {
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
}
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
SYSTEM_INFO si;
GetSystemInfo(&si);
return (size_t) si.dwPageSize;
}
bool raw_lock(void * ptr, size_t len) {
for (int tries = 1; ; tries++) {
if (VirtualLock(ptr, len)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
len, size, llama_format_win_err(GetLastError()).c_str());
return false;
}
// It failed but this was only the first try; increase the working
// set size and try again.
SIZE_T min_ws_size, max_ws_size;
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
// Per MSDN: "The maximum number of pages that a process can lock
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = len + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
}
}
void raw_unlock(void * ptr, size_t len) {
if (!VirtualUnlock(ptr, len)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
size_t lock_granularity() {
return (size_t) 65536;
}
bool raw_lock(const void * addr, size_t len) {
fprintf(stderr, "warning: mlock not supported on this system\n");
return false;
}
void raw_unlock(const void * addr, size_t len) {}
#endif
};
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
llama_buffer() = default;
void resize(size_t len) {
delete[] addr;
addr = new uint8_t[len];
size = len;
}
~llama_buffer() {
delete[] addr;
}
// disable copy and move
llama_buffer(const llama_buffer&) = delete;
llama_buffer(llama_buffer&&) = delete;
llama_buffer& operator=(const llama_buffer&) = delete;
llama_buffer& operator=(llama_buffer&&) = delete;
};
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
struct llama_ctx_buffer {
uint8_t * addr = NULL;
bool is_cuda;
size_t size = 0;
llama_ctx_buffer() = default;
void resize(size_t size) {
free();
addr = (uint8_t *) ggml_cuda_host_malloc(size);
if (addr) {
is_cuda = true;
}
else {
// fall back to pageable memory
addr = new uint8_t[size];
is_cuda = false;
}
this->size = size;
}
void free() {
if (addr) {
if (is_cuda) {
ggml_cuda_host_free(addr);
}
else {
delete[] addr;
}
}
addr = NULL;
}
~llama_ctx_buffer() {
free();
}
// disable copy and move
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
};
#else
typedef llama_buffer llama_ctx_buffer;
#endif
#endif

File diff suppressed because it is too large Load Diff

273
examples/talk-llama/llama.h Normal file
View File

@ -0,0 +1,273 @@
#ifndef LLAMA_H
#define LLAMA_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
# define LLAMA_API __declspec(dllimport)
# endif
# else
# define LLAMA_API __attribute__ ((visibility ("default")))
# endif
#else
# define LLAMA_API
#endif
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_FILE_VERSION 3
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 1
#ifdef __cplusplus
extern "C" {
#endif
//
// C interface
//
// TODO: show sample usage
//
struct llama_context;
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
} llama_token_data;
typedef struct llama_token_data_array {
llama_token_data * data;
size_t size;
bool sorted;
} llama_token_data_array;
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_gpu_layers; // number of layers to store in VRAM
int seed; // RNG seed, -1 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
};
LLAMA_API struct llama_context_params llama_context_default_params();
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// TODO: not great API - very likely to change
// Initialize the llama + ggml backend
// Call once at the start of the program
LLAMA_API void llama_init_backend();
LLAMA_API int64_t llama_time_us();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
LLAMA_API struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
// TODO: not great API - very likely to change
// Returns 0 on success
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
enum llama_ftype ftype,
int nthread);
// Apply a LoRA adapter to a loaded model
// path_base_model is the path to a higher quality model to use as a base for
// the layers modified by the adapter. Can be NULL to use the current loaded model.
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
// will be applied on top of the previous one
// Returns 0 on success
LLAMA_API int llama_apply_lora_from_file(
struct llama_context * ctx,
const char * path_lora,
const char * path_base_model,
int n_threads);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
// Set the state reading from the specified address
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src);
// Save/load session file
LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
// Returns 0 on success
LLAMA_API int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token
// Rows: n_tokens
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
LLAMA_API llama_token llama_token_nl();
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
/// @details Selects the token with the highest probability.
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
#ifdef __cplusplus
}
#endif
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
#include <vector>
#include <string>
struct ggml_tensor;
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
#endif
#endif // LLAMA_H

View File

@ -0,0 +1,23 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
### Response:
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4}

24
examples/talk-llama/speak Normal file
View File

@ -0,0 +1,24 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
# for Mac
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with limited number of characters. To increase this limit register to https://beta.elevenlabs.io to get an api key and paste it after 'ELEVEN_API_KEY='
#Keep the line commented to use the free version whitout api key
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1

View File

@ -0,0 +1 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

View File

@ -0,0 +1,12 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
# voice options are David or Zira
[Parameter(Mandatory=$true)][string]$voice,
[Parameter(Mandatory=$true)][string]$text
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$speak.SelectVoice("Microsoft $voice Desktop");
$speak.Rate="0";
$speak.Speak($text);

View File

@ -0,0 +1,670 @@
// Talk with AI
//
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "llama.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int)add_bos);
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);
return res;
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak";
std::string prompt = "";
std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string prompt_text,
float & prob,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
std::vector<whisper_token> prompt_tokens;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4})";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// llama init
llama_init_backend();
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 2048;
lparams.seed = 1;
lparams.f16_kv = true;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
const std::string chat_symb = ":";
const std::string bot_name = "LLaMA";
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
// construct the initial prompt for LLaMA inference
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
// need to have leading ' '
prompt_llama.insert(0, 1, ' ');
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
{
// get time string
std::string time_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%H:%M", now);
time_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
}
{
// get year string
std::string year_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%Y", now);
year_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
}
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
// init session
std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
// fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(lparams.n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
session_tokens.resize(n_token_count_out);
for (size_t i = 0; i < session_tokens.size(); i++) {
embd_inp[i] = session_tokens[i];
}
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
}
// evaluate the initial prompt
printf("\n");
printf("%s : initializing - please wait ...\n", __func__);
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
if (params.verbose_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s", prompt_llama.c_str());
fflush(stdout);
}
// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
}
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str());
fflush(stdout);
// clear audio buffer
audio.clear();
// text inference variables
const int voice_id = 2;
const int n_keep = embd_inp.size();
const int n_ctx = llama_n_ctx(ctx_llama);
int n_past = n_keep;
int n_prev = 64; // TODO arg
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
std::vector<llama_token> embd;
// reverse prompts for detecting when it's time to stop speaking
std::vector<std::string> antiprompts = {
params.person + chat_symb,
};
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
if (text_heard.empty() || tokens.empty() || force_speak) {
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
text_heard.insert(0, 1, ' ');
text_heard += "\n" + bot_name + chat_symb;
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
fflush(stdout);
embd = ::llama_tokenize(ctx_llama, text_heard, false);
// Append the new input tokens to the session_tokens vector
if (!path_session.empty()) {
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
}
// text inference
bool done = false;
std::string text_to_speak;
while (true) {
// predict
if (embd.size() > 0) {
if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep;
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
// stop saving session if we run out of context
path_session = "";
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}
n_past++;
n_session_consumed++;
if (n_session_consumed >= (int) session_tokens.size()) {
i++;
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
}
if (embd.size() > 0 && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
{
// out of user input, sample next token
const float top_k = 5;
const float top_p = 0.80f;
const float temp = 0.30f;
const float repeat_penalty = 1.1764f;
const int repeat_last_n = 256;
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(ctx_llama);
logits[llama_token_eos()] = 0;
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// apply repeat penalty
const float nl_logit = logits[llama_token_nl()];
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, repeat_penalty);
logits[llama_token_nl()] = nl_logit;
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temperature(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}
}
if (id != llama_token_eos()) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_str(ctx_llama, id);
printf("%s", llama_token_to_str(ctx_llama, id));
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_str(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
done = true;
text_to_speak = ::replace(text_to_speak, antiprompt, "");
fflush(stdout);
need_to_save_session = true;
break;
}
}
}
is_running = sdl_poll_events();
if (!is_running) {
break;
}
}
text_to_speak = ::replace(text_to_speak, "\"", "");
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
llama_print_timings(ctx_llama);
llama_free(ctx_llama);
return 0;
}

View File

@ -13,6 +13,7 @@ include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
common
)
unset(EXTRA_FLAGS)

View File

@ -1,4 +1,6 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
@ -14,150 +16,6 @@
/////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.push_back(std::make_pair(logits[i], i));
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int)logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
@ -165,7 +23,7 @@ struct gpt2_hparams {
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t f16 = 1;
int32_t ftype = 1;
};
struct gpt2_layer {
@ -187,7 +45,7 @@ struct gpt2_layer {
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
@ -198,8 +56,9 @@ struct gpt2_model {
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
@ -241,14 +100,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: f16 = %d\n", __func__, hparams.f16);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
@ -275,9 +134,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}
}
// for the big tensors, we have the option to store the data in 16-bit floats
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
@ -291,32 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
@ -325,9 +190,11 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// create the ggml context
{
struct ggml_init_params params;
params.mem_size = ctx_size;
params.mem_buffer = NULL;
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = NULL,
.no_alloc = false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
@ -350,36 +217,38 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
@ -397,7 +266,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
@ -425,14 +294,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
@ -461,13 +332,18 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
if (nelements*bpe != ggml_nbytes(tensor)) {
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
@ -475,7 +351,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
@ -493,7 +377,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
// - embd_w: the predicted logits for the next token
//
bool gpt2_eval(
const gpt2_model & model,
@ -512,12 +396,12 @@ bool gpt2_eval(
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 640u*1024*1024;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
@ -528,13 +412,14 @@ bool gpt2_eval(
}
}
struct ggml_init_params params;
params.mem_size = buf_size;
params.mem_buffer = buf;
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { };
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@ -578,7 +463,7 @@ bool gpt2_eval(
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
@ -654,11 +539,13 @@ bool gpt2_eval(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
@ -685,7 +572,7 @@ bool gpt2_eval(
// [768, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
@ -722,7 +609,7 @@ bool gpt2_eval(
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
@ -742,7 +629,7 @@ bool gpt2_eval(
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w_trans,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
@ -769,12 +656,12 @@ bool gpt2_eval(
}
// inpL = WTE * inpL
// [ 768, 50257] - model.wte
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
@ -788,7 +675,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@ -825,7 +712,7 @@ Me too.
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 40;
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
@ -833,14 +720,15 @@ Me too.
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(NULL));
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
@ -884,9 +772,9 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
std::string result;
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (embd.size() > 0) {
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
@ -913,10 +801,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256 ||
ctx->vocab.id_to_token[embd.back()] == "." ||
ctx->vocab.id_to_token[embd.back()] == "!" ||
ctx->vocab.id_to_token[embd.back()] == "?") {
if (embd.back() == 50256) {
break;
}
}

View File

@ -2,18 +2,12 @@
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);

View File

@ -44,6 +44,15 @@
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the models you would like to use and click the "Start" button to begin the conversation
@ -54,6 +63,10 @@
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
@ -266,11 +279,17 @@
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
};
let url = urls[model];
@ -281,6 +300,10 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
@ -292,6 +315,10 @@
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};

View File

@ -1 +1 @@
eleven-labs.py
audio.mp3

View File

@ -1,16 +1,8 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# talk
set(TARGET talk)
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
add_executable(${TARGET} talk.cpp gpt-2.cpp)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -37,5 +37,5 @@ wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.c
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use `espeak`, but you can use whatever you wish.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or `espeak` or Windows SpeechSynthesizer, but you can use whatever you wish.

View File

@ -0,0 +1,20 @@
import sys
import importlib.util
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import generate, play, save
# Get a Voice object, by name or UUID
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
# Save the TTS to a file
save(audio, "audio.mp3")

View File

@ -1,4 +1,6 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
@ -14,150 +16,6 @@
/////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double /*temp*/,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.emplace_back(logits[i], i);
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int) logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
@ -165,7 +23,7 @@ struct gpt2_hparams {
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t f16 = 1;
int32_t ftype = 1;
};
struct gpt2_layer {
@ -187,7 +45,7 @@ struct gpt2_layer {
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
@ -198,8 +56,9 @@ struct gpt2_model {
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
@ -241,14 +100,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: f16 = %d\n", __func__, hparams.f16);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
@ -268,16 +127,21 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) &word[0], len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
@ -291,32 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
@ -325,9 +190,11 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// create the ggml context
{
struct ggml_init_params params;
params.mem_size = ctx_size;
params.mem_buffer = nullptr;
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = NULL,
.no_alloc = false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
@ -350,36 +217,38 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
@ -397,7 +266,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
@ -425,14 +294,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
@ -448,7 +319,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name) == model.tensors.end()) {
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
@ -461,13 +332,18 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
if (nelements*bpe != ggml_nbytes(tensor)) {
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
@ -475,7 +351,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
@ -493,7 +377,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
// - embd_w: the predicted logits for the next token
//
bool gpt2_eval(
const gpt2_model & model,
@ -512,12 +396,12 @@ bool gpt2_eval(
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 5640ull*1024*1024;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
@ -528,13 +412,14 @@ bool gpt2_eval(
}
}
struct ggml_init_params params;
params.mem_size = buf_size;
params.mem_buffer = buf;
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { };
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@ -578,7 +463,7 @@ bool gpt2_eval(
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
@ -654,11 +539,13 @@ bool gpt2_eval(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
@ -685,7 +572,7 @@ bool gpt2_eval(
// [768, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
@ -722,7 +609,7 @@ bool gpt2_eval(
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
@ -742,7 +629,7 @@ bool gpt2_eval(
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w_trans,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
@ -769,12 +656,12 @@ bool gpt2_eval(
}
// inpL = WTE * inpL
// [ 768, 50257] - model.wte
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
@ -788,7 +675,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

View File

@ -2,18 +2,12 @@
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);

24
examples/talk/speak Normal file
View File

@ -0,0 +1,24 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
# Mac OS "say" command
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with limited number of characters. To increase this limit register to https://beta.elevenlabs.io to get an api key and paste it after 'ELEVEN_API_KEY='
#Keep the line commented to use the free version without api key
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2"
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3

1
examples/talk/speak.bat Normal file
View File

@ -0,0 +1 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

Some files were not shown because too many files have changed in this diff Show More