Compare commits

117 Commits

Author SHA1 Message Date
3c50be2217 whisper : remove comment 2023-09-10 13:27:06 +03:00
37de5dcf2b command : enable beam-search, add "no_timestamps", add "context", add p 2023-09-10 12:22:57 +03:00
7a2abb311d grammars : add assistant + update comments 2023-09-09 20:24:58 +03:00
54d168db67 command : grammar-related improvements
- option to read grammar from file
- add sample grammars for colors and chess moves
- fine-tune the performance further
2023-09-09 20:05:57 +03:00
b8f34d1ed7 whisper : fine-tuning grammar functionality 2023-09-06 17:05:05 +03:00
97ebb48b99 command : fix exception when recognizing the command 2023-09-06 15:05:17 +03:00
b0306cd5cf build : fix after master merge 2023-09-06 14:28:25 +03:00
afc84b35b0 Merge branch 'master' into HEAD 2023-09-06 13:33:34 +03:00
c3f319d7c2 ggml : sync latest llama.cpp (view_src + alloc improvements) (#1247)
* ggml : sync latest llama.cpp (view_src + alloc improvements)

* ggml : fix build
2023-09-05 20:57:27 +03:00
ba3c333611 make : improve cpuinfo handling on x86 hosts (#1238)
* make : simplify and correct x86 ISA extensions detection on the host

It got broken in commit c5f9acf4b7 for Haiku and macOS (Intel),
which report CPU features in upper case.

Now we're finding the names in a case-insensitive manner and as words.
SSE3 detection has been corrected for Linux, which uses PNI for that
(Prescott New Instructions).

* make : use dmesg.boot in FreeBSD/DragonFlyBSD to detect x86 ISA extensions on the host

* make : enable x86 ISA extensions on the host both in CFLAGS and CXXFLAGS

* make : correct AVX x86 ISA extension detection on macOS (Intel) host

It got broken in commit c5f9acf4b7. macOS calls it AVX1.0.
2023-09-05 14:58:47 +03:00
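The key detail in the detection change above is whole-word, case-insensitive matching: a plain substring test for "AVX" would also match "AVX2", and Haiku/macOS report the flags in upper case. A minimal C++ sketch of the same idea the Makefile implements with `grep -iwE` (the feature strings here are illustrative):

    #include <iostream>
    #include <regex>
    #include <string>

    // Whole-word, case-insensitive test for a CPU feature name, mirroring
    // the Makefile's `grep -iwE`: "AVX" must not match "AVX2".
    static bool has_flag(const std::string & cpuinfo, const std::string & pattern) {
        const std::regex re("\\b(" + pattern + ")\\b", std::regex::icase);
        return std::regex_search(cpuinfo, re);
    }

    int main() {
        const std::string cpuinfo = "fpu vme pni ssse3 fma avx avx2"; // e.g. /proc/cpuinfo flags
        std::cout << has_flag(cpuinfo, "AVX|AVX1\\.0") << '\n'; // 1
        std::cout << has_flag(cpuinfo, "PNI|SSE3")     << '\n'; // 1 (Linux reports SSE3 as PNI)
        return 0;
    }
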
59a3d0cb57 ggml : sync (ggml-alloc, GPU, eps, etc.) (#1220)
* ggml : sync (ggml-alloc, GPU, eps, etc.)

* ggml : fix build

* wasm : fix build
2023-09-05 13:54:40 +03:00
6780c98e19 readme : update CMake build commands (#1231)
* Update README.md

* Update README.md: `vcpkg install opencl clblast`

* readme : update build commands

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-09-05 13:53:34 +03:00
2f52783a08 OSSRH_USERNAME -> JIRA_USER 2023-08-31 14:54:02 +10:00
7dec9d8cc4 build-root-directory: bindings/java 2023-08-31 12:04:16 +10:00
31476ccc0e whisper : add grammar-based sampling 2023-08-30 20:00:31 -04:00
fb0a24fba2 ci : enable java package publishing (#1228) 2023-08-31 09:57:43 +10:00
8e30bf3c02 ggml : fix compilation errors incurred by -Werror (#1227)
The -Werror warning option turns all warnings into errors. This PR makes
the compiler happy to build ggml.c and whisper.cpp with the stricter option.
2023-08-30 22:09:15 +03:00
99d3c105f5 whisper.android : fix cmake multiple libraries build (#1224)
* whisper.android : fix multiple libraries build

* fix flags for default target
2023-08-30 14:45:13 +03:00
18e9889418 coreml : wrap inference call in @autoreleasepool to fix memory leak (#1218) 2023-08-29 15:44:38 +03:00
8e46ba80d3 make : use cpuinfo in MSYS2 to enable x86 ISA extensions on the host (#1216) 2023-08-28 13:28:26 +03:00
b0d35995c4 make : add support for building on DragonFlyBSD/NetBSD/OpenBSD (#1212) 2023-08-27 21:38:46 +03:00
25466aa1c3 ggml : fix compiling when SSE3 is available but not SSSE3 (#1210)
It got broken in commit 3998465721.
2023-08-27 21:37:31 +03:00
601c2d2181 ggml : detect SSSE3 (#1211)
* ggml : add ggml_cpu_has_ssse3

* whisper : show SSSE3 in system info

* make : detect SSSE3 via cpuinfo
2023-08-27 21:36:41 +03:00
175ffa64ee examples : vim plugin and LSP server (#1144)
* Initial proof of concept Vim plugin

At present, this is likely only slightly better than feature parity with
the existing whisper.nvim

Known issues:
 Trailing whitespace
 Up to an existing length (5 seconds) of speech may be processed when
  listening is enabled
 CPU cycles are spent processing speech even when not listening.

Fixing these issues is likely dependent upon future efforts to create a
dedicated library instead of wrapping examples/stream

* Support $WHISPER_CPP_HOME environment variable

A minor misunderstanding of the whisper.nvim implementation resulted in
a plugin that was functional, but not a drop-in replacement, as it should
be now.

* Initial progress on LSP implementation

Libcall is nonviable because the library is immediately freed after a
call is made. Further investigation has shown Language Server Protocol as
a promising alternative that both simplifies the required logic on the
vimscript side and increases the ease with which plugins for other
editors could be made in the future. This is a very large undertaking
and my progress has slowed substantially.

Work is far from being in a usable state, but I wish to keep track of
major refactors for organizational purposes.

* Rewrite audio windowing of guided transcription

One of the defining goals of this venture is allowing consecutive
commands to be rattled off without the dead zones of the current
implementation.

* Add unguided_transcription. Cleanup.

The unguided transcription implementation heavily borrows from existing
example implementations and the guided_transcription logic.

A high-level pass was done to check that method arguments match the
inputs that are actually required.

A first attempt at cancellation support was added for record keeping,
but will be deleted in a future commit.

* Fix compilation.

Resolves a large number of compilation errors.
No testing has been done yet for execution errors.

Update Makefile and .gitignore

* Functional unguided_transcription

* Functional guided_transcription

Fix commandset_list being passed by value
Properly register the first token of a multitoken command

* Minor changes before time fix

I've apparently made an awfully major mistake in thinking that unix time
was in milliseconds and will be changing all timekeeping code to use
standardized methods.

In preparation for this are a number of minor bugfixes:
Output is manually flushed.
An echo method has been added.
registerCommandset now wraps the returned index

* Swap timekeeping to use std::chrono

* Add work in progress lsp backed whisper.vim plugin

Current progress blockers are
 Adding modality awareness to the command processing
  (specifically, motion prompting)
 Improving the VAD to be a little more responsive
  (testing start of activity)

* Reworked vim plugin command loop

* Fix change inside

Multiple bug fixes that, crucially, bring the plugin to the point where a
demonstration video is possible

Add better echo messaging so whisper_log isn't required
 Add loading complete message as indicator when listening has started
Insert/append are actually included in command sets
Some more heavy-handed corrections to prevent a double exit when leaving
insert mode
As a somewhat hacky fix, the very first space is removed when inserting.
 This cleans up most use cases, but leaves me unsatisfied with the few
 cases where it would be desired.

* Forcibly set commandset_index to 0 after subinsert

Also remove unnecessary ! to use builtin vim command

* Fix upper

A minor scope mistake was causing upper'd inputs to be eaten.
This was fixed and echoing was slightly improved for clarity.

* Fix formatting

Corrects indentation to 4 spaces as project standard
Slightly better error support for malformed json input

* Remove obsolete vim plugin

* Add json.hpp library

The same library that is used for the llama.cpp server

* Minor cleanups

add lsp to the make clean directive.
remove a redundant params definition.
reorder whisper.vim logging for subtranscriptions
Corrections to unlets (variables of argument scope appear immutable)

* Fix indentation. Fallback for subTranscription

Indentation has been changed to 4 spaces.

Unit testing has been set up; I'm opting not to include it in the
repository for now.
It has, however, revealed a bug in the state logic where a
subtranscription can be initiated without having a saved command.
When this occurs, append is added as a fallback.

* Move audio polling logic to a subfunction

While work on the improved VAD will continue, it's grown to be a little
out of scope. Instead, a future commit will perform multiple detection
passes at substretches of audio when a backlog of audio exists.

To facilitate this, and prevent code duplication, the vad code has been
moved into a subfunction shared by both the unguided and guided
transcription functions.

* Test for voice over subchunks if backlog > 1s

As the existing VAD implementation only checks for a falling edge at the
end of an audio chunk, it fails to detect voice in cases where the
recorded voice is only at the beginning of the audio.

To ameliorate this, when the timestamp would cause analysis of audio
over a second in length, it is split into 1-second subchunks which are
individually tested (a sketch of this follows the commit entry).

Results are promising, but there seems to be a remaining bug with
unguided transcription likely related to saving context

* Limit the maximum length of audio input.

The existing VAD implementation only detects falling edges, which
means any gap in the user's speaking is processed for transcription.
This simply establishes a constant maximum length depending on the type
of transcription: unguided gets a generous 10 seconds and guided, 2.

While quick testing showed that commands are generally around half a
second to a second, limiting commands to an even second resulted in
extreme degradation of quality. (Seemingly always the same output for a
given commandset)

* Unguided timestamp tracking, cleanup

Unguided transcriptions were not set up to allow passing timestamp
data forward; this has been corrected.

No_context is now always set to false. While conceptually desirable for
the quality of guided transcription, it was seemingly responsible for
prior command inputs ghosting in unguided transcription.

Save and Run are now tracked by command number instead of command text.
While command_text was provided for convenience, I wish to keep command
index authoritative. This gives greater consistency and potentially
allows for end users to rename or even translate the spoken versions of
these commands

* By default, maintain mode.

Previously, mode was reset to 0 unless otherwise set.
In addition to causing some edge cases, this didn't mesh well with
the existing approach to visual mode.

With this change, initial tests indicate visual mode is functional.

* Add undo breaks before subtranscriptions

Subtranscriptions use undo as a hack to allow for partial responses to
be displayed. However, scripts don't cause an undo break mid-execution
unless specifically instructed to. This meant that multiple
unguided transcriptions from a single session would cause a later one to
undo an earlier one.

This is now fixed and undo should be reasonably usable as a command.

* Append instead of insert for new undo sequence

When entering and leaving insert mode with `i`, the cursor shifts one
column to the left. This is remedied by using append instead of insert
for setting these breaks in the undo sequence

`-` was also added to the pronunciation dictionary to be pronounced as
minus as it was causing a particularly high failure rate.

* Move undo sequence breaks to command execution

Previously, undo sequence breaks were triggered when there was a command
that caused a move to insert mode. This caused commands that changed
state (like delete or paste) to be bundled together into the last
command that caused text to be entered.

* Fix repeat. Add space, carrot, dollar commands

 Repeat (.) wasn't being tracked properly, just like undo, and is being
 manually tracked now.

 While efforts have been made to properly handle spaces, it was
 particularly finicky to add a single space when one is needed. A
 special 'space' command has been added to insert a single space and move
 the cursor after it.

 Carrot and Dollar commands have been added for start of line and end of
 line respectively. These are both simple to implement, and just a
 matter of defining a pronunciation.

* Return error on duplicate in commandset

Not every command in the commandset tokenizes to a single token.
Because of this, it's possible that two commands could resolve to
the same single token after subsequent tokens are discarded.

This commit adds a simple check for duplicates when a commandset is
registered and returns an error if so (sketched in code after this entry).

Additional code will be required later on the vim side to actually
process this error.

* Add support for user-defined commands

This adds a user-definable dictionary from spoken keys to strings or
funcrefs. All keys are added to the commandlist and when spoken, trigger
the corresponding function.

Like "save" and "run", these user commands are only available when the
command buffer is empty.

* Add readme, update cmake

* Add area commandset. Refactor spoken_dict

Area commands (inside word, around sentence...) have been given a
commandset as considered earlier.

Verbose definitions for spoken_dict entries now use dicts instead of
lists. This shortens the definition for most keys that require it and
scales better with the addition of further commandsets

* Add mark, jump. Fix change under visual.

Mark (m) and jump (') have been added.

When a command that initiated a subtranscription (change) was executed
on a visual selection, the area of the selection was not properly
tracked, which caused the attempt to stream in the partial response to
fail. This is solved by disabling partial transcriptions from being
streamed when a subtranscription is started while in visual mode.

* Accommodate ignorecase. Fix change.

From testing on different older versions of vim, the test for
distinguishing an 'R' replace-all from an 'r' replace could fail if
ignorecase was set. The comparison has been changed to explicitly
require case matching.

Change detection has been moved to the execution section as it was missing the
change+motion case.

* Support registers. Fix README typo

There's no logic to prevent doubled register entry, but the functional
result is equivalent to typing the same key order into vim.

A minor typo in the readme: I'd mismemorized the mnemonic for 't' as 'to'
instead of 'till', but 'to' can't be used as it's a homophone of '2'.
While there was no mistake in the actual logic, it was misleading to use
'to' in the readme.
2023-08-27 21:35:06 +03:00
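Two of the changes above lend themselves to short sketches. First, the subchunk voice test: since the falling-edge VAD misses speech at the start of a long backlog, audio over a second long is split into 1-second subchunks that are tested individually. A minimal C++ sketch, where a toy RMS detector stands in for the plugin's real VAD:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Toy stand-in for the plugin's falling-edge VAD check on one chunk:
    // here just an RMS-energy threshold.
    static bool voice_detected(const std::vector<float> & chunk, float thold = 0.01f) {
        if (chunk.empty()) return false;
        double e = 0.0;
        for (float s : chunk) e += (double) s * s;
        return std::sqrt(e / chunk.size()) > thold;
    }

    // Test 1-second subchunks so speech at the start of a long backlog is
    // not missed by an end-of-chunk falling-edge check.
    static bool voice_in_backlog(const std::vector<float> & pcm, int sample_rate) {
        const size_t step = (size_t) sample_rate; // 1 second of mono samples
        for (size_t off = 0; off < pcm.size(); off += step) {
            const size_t len = std::min(step, pcm.size() - off);
            const std::vector<float> sub(pcm.begin() + off, pcm.begin() + off + len);
            if (voice_detected(sub)) return true;
        }
        return false;
    }

Second, the duplicate check on commandset registration: only the first token of a multi-token command is matched, so two commands sharing a first token are ambiguous. A sketch with a toy word tokenizer (the real code would consult the whisper vocabulary instead):

    #include <sstream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    // Toy tokenizer: one "token" per whitespace-separated word.
    static std::vector<std::string> tokenize(const std::string & command) {
        std::istringstream in(command);
        std::vector<std::string> toks;
        for (std::string t; in >> t; ) toks.push_back(t);
        return toks;
    }

    // Reject a commandset in which two commands collapse to the same first token.
    static bool commandset_is_unambiguous(const std::vector<std::string> & commands) {
        std::unordered_set<std::string> seen;
        for (const auto & cmd : commands) {
            const auto toks = tokenize(cmd);
            if (toks.empty() || !seen.insert(toks[0]).second) {
                return false; // empty command or duplicate first token
            }
        }
        return true;
    }
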
cb5fb0a12d whisper : initial hipBLAS support (#1209) 2023-08-27 20:03:58 +03:00
b5bb5c85d4 whisper : allow whisper_full from mel spectrogram - no audio (#1214)
Co-authored-by: jbrough <jamie1612@gmail.com>
2023-08-27 20:02:57 +03:00
7e54df414e whisper : significantly improve the inference quality (#1148)
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-08-27 19:51:33 +03:00
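The spectrum folding described above can be sketched briefly: for a real input signal the power spectrum is symmetric around the Nyquist bin, so the right half is folded onto the left and averaged, and bin 0 is kept rather than discarded. A sketch following the commit description (indexing details are simplified relative to the upstream source):

    #include <vector>

    // Fold a full-size power spectrum of a real signal down to N/2 + 1 bins:
    // mirror bins around Nyquist are averaged, bin 0 is kept as-is.
    // Assumes a non-empty, even-sized input.
    static std::vector<float> fold_spectrum(const std::vector<float> & power) {
        if (power.empty()) return {};
        const int n = (int) power.size(); // full FFT size
        std::vector<float> half(n / 2 + 1);
        half[0] = power[0];
        for (int j = 1; j <= n / 2; j++) {
            half[j] = 0.5f * (power[j] + power[n - j]); // symmetric bins averaged
        }
        return half;
    }
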
20a80972f4 whisper.android : migrate from ndk-build to CMake (#1204) 2023-08-27 19:35:16 +03:00
7ef3f3837e main : log probs to text file (#1205)
* token/probability file generated with -ls

* code comment cleaning

* main : indentations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-08-27 19:09:06 +03:00
aad2dad38a whisper : minor fixes (#1154) 2023-08-27 19:02:00 +03:00
66f2078878 build : fix OpenBLAS detection under Arch Linux (#1173) 2023-08-25 19:26:34 +03:00
8ce20f0f3d make : fix Linux machines supporting AVX1 not AVX2 (#1162)
e.g. ancient CPU E5-2670 (v1)

See issue #1126

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-08-25 15:52:22 +03:00
c84cf87261 whisper : add precalculated values of sin/cos for speeding up FFT (#1142)
* Add sin/cos precalculated values to speed up FFT

* Update whisper.cpp

Co-authored-by: bobqianic <129547291+bobqianic@users.noreply.github.com>

* Update whisper.cpp

Co-authored-by: bobqianic <129547291+bobqianic@users.noreply.github.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: bobqianic <129547291+bobqianic@users.noreply.github.com>
2023-08-25 15:51:14 +03:00
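The precalculation replaces per-butterfly sinf/cosf calls with a table lookup. A minimal sketch of such a twiddle table (the size and layout here are assumptions, not the exact upstream code):

    #include <cmath>
    #include <vector>

    // Precompute sin/cos twiddle factors once for a fixed FFT size so the
    // FFT inner loop can index a table instead of calling sinf/cosf.
    struct TwiddleTable {
        std::vector<float> sin_vals, cos_vals;
        explicit TwiddleTable(int n) : sin_vals(n), cos_vals(n) {
            const double pi = 3.14159265358979323846;
            for (int i = 0; i < n; i++) {
                sin_vals[i] = (float) std::sin(2.0 * pi * i / n);
                cos_vals[i] = (float) std::cos(2.0 * pi * i / n);
            }
        }
    };

    // Built once (e.g. for whisper's fixed 400-sample FFT) and reused by
    // every mel-spectrogram worker thread.
    static const TwiddleTable g_twiddle(400);
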
c5f9acf4b7 make : simplify Makefile (#1147)
* Simplify architecture-specific section in Makefile

* Unify OS-specific check
2023-08-25 15:20:44 +03:00
7decc85eb7 cmake : fix PowerPC build failures introduced in #1174 (#1196) 2023-08-25 15:19:48 +03:00
21e8c67a4f Fix AVX etc. under GCC/CMake (#1174) 2023-08-19 21:39:03 +03:00
a4bb2df36a quantize : fix load vocab crash when len is 128 (#1160)
* quantize : fix load vocab crash when len is 128

* ci : add quantize job
2023-08-06 11:04:42 +03:00
b948361956 examples : add tinydiarization support for streaming (#1137) 2023-08-03 11:24:07 +03:00
a792c4079c cmake : fix MSVC compile error C3688 (#1136)
Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.
2023-07-26 18:57:25 +03:00
7b374c9ac9 Revert "cmake : fix MSVC compile error C3688 on non-unicode Windows (#1110)"
This reverts commit fe5c1a7341.
2023-07-26 10:25:09 +03:00
a32c4aa482 whisper : fix visibility warning of struct whisper_full_params by declaring in advance (#1124) 2023-07-25 19:15:57 +03:00
a195bf899a cmake : enable OpenBLAS on Windows (#1128)
Fixed the issue of not being able to find OpenBLAS on the Windows platform. Even though the name of the previously released binary file was whisper-blas-bin-x64.zip, BLAS was actually not enabled. After enabling, the inference speed can increase by 3-4 times.
2023-07-25 19:15:08 +03:00
ded17dc1cf make : fix CLBlast build on macOS (#1120) 2023-07-25 19:12:03 +03:00
a0bb409f51 make : check nvcc version and set flag (#1115) 2023-07-25 19:10:54 +03:00
a2684cd93a go : implement SetSplitOnWord (#1114)
* Go binding: Implement SetSplitOnWord

* Add comment for consistency
2023-07-25 19:10:12 +03:00
1450346214 make : tests can be called as "make tests base.en" (#1113) 2023-07-25 19:09:38 +03:00
fe5c1a7341 cmake : fix MSVC compile error C3688 on non-unicode Windows (#1110)
Co-authored-by: Gang Chen <cg@upiot.net>
2023-07-25 19:08:37 +03:00
1fa360fc6e readme : add OpenVINO support details (#1112) 2023-07-25 19:07:59 +03:00
41bf19f613 opencl : sync opencl compilation fix in ggml (#1111) 2023-07-25 19:07:08 +03:00
9ad35bd740 samples : add a larger (30min) sample (#1092)
Co-authored-by: Vadim Peretokin <vadim.peretokin@carasent.com>
2023-07-25 19:00:45 +03:00
fabf79fc67 whisper : expose API to let user control log output (#1060)
* expose api to let user control log output

Add
  whisper_set_log_callback()
that lets the user set a callback for log messages.

Change all the
  fprintf(stderr, ...)
to call via the above.

* whisper : add <cstdarg>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-25 18:58:25 +03:00
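A usage sketch for the new callback; this assumes the form added here, where the callback receives one pre-formatted line (the signature changed in later releases, so check whisper.h for your version):

    #include <cstdio>
    #include "whisper.h"

    // Redirect whisper.cpp's internal log lines to our own sink instead of stderr.
    static void my_log_sink(const char * line) {
        std::fprintf(stdout, "[whisper] %s", line);
    }

    int main() {
        whisper_set_log_callback(my_log_sink);
        // ... load a model and run whisper_full() as usual ...
        return 0;
    }
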
925915ae37 whisper : move progress calculation out of whisper.cpp (#1081)
The current `progress_step` was hardcoded into whisper.cpp; this resulted
in bindings being able to observe progress only at that step, even though
the progress callback was being called at every iteration.

With this change we get finer-grained progress reporting from
whisper.cpp, and bindings/implementations can define their own progress step.
2023-07-25 18:53:34 +03:00
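With the fixed step gone, a caller applies whatever step it wants in its own callback. A sketch against the whisper_full_params progress fields (field names as in whisper.h of this period; verify against your version):

    #include <cstdio>
    #include "whisper.h"

    // Report progress in 5% steps, even though the callback now fires with
    // finer granularity.
    static void on_progress(struct whisper_context * /*ctx*/,
                            struct whisper_state * /*state*/,
                            int progress, void * user_data) {
        int * last = (int *) user_data;
        if (progress >= *last + 5) {
            *last = progress;
            std::fprintf(stderr, "progress: %3d%%\n", progress);
        }
    }

    // Usage sketch:
    //   int last = 0;
    //   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    //   wparams.progress_callback           = on_progress;
    //   wparams.progress_callback_user_data = &last;
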
97f4a7fee0 examples : add Vim plugin (#1131)
* Initial proof of concept Vim plugin

At present, this is likely only slightly better than feature parity with
the existing whisper.nvim

Known issues:
 Trailing whitespace
 Up to an existing length (5 seconds) of speech may be processed when
  listening is enabled
 CPU cycles are spent processing speech even when not listening.

Fixing these issues is likely dependent upon future efforts to create a
dedicated library instead of wrapping examples/stream

* Support $WHISPER_CPP_HOME environment variable

A minor misunderstanding of the whisper.nvim implementation resulted in
a plugin that was functional, but not a drop-in replacement, as it should
be now.
2023-07-25 18:34:23 +03:00
3998465721 ci : more platforms coverage (#1101)
* add multi platform

* add image name

* fix

* fix /bin/sh path

* add missing \

* add all platforms for check

* remove platforms

* remove s390x

* - add arm v6
- format run cmd

* remove arm v6

* - bump checkout to v3
- use setup emsdk action
- add arch to all ubuntu jobs

* mymindstorm/setup-emsdk to v12

* add missing QEMU step

* add fail-fast: false for debug

* add freebsd

* remark all jobs except freebsd for test

* add sudo

* enable all tests again

* format

* check __AVX__ support before include immintrin.h

* try auto detect flag by cmake

* fix check for immintrin.h

* fix include check for immintrin.h

* Remove all platforms for sanitizer build except amd64

We have no clue why they failed.

---------

Co-authored-by: Alon Faraj <alon.faraj@mapcore.com>
2023-07-16 23:00:34 +03:00
4774d2feb0 whisper : minor OpenVINO refactoring (#1037)
Hopefully I didn't break something - haven't tested
2023-07-04 20:28:27 +03:00
6f0114f4a6 go : call SetDuration appropriately (#1077) 2023-07-04 16:13:25 +03:00
66616dbd4d go : fix context.Process call in examples (#1067) 2023-07-04 16:05:35 +03:00
62b81276e0 whisper : add OpenVINO support (#1037)
* openvino: use OpenVINO encoder inference

* openvino: add python script for OpenVINO model generation

* whisper: Fix 'unused' warnings when OpenVINO isn't enabled in build

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* whisper: Fix compilation error

* whisper: revert whisper_get_openvino_path_encoder & whisper_get_openvino_path_cache to non-const func signatures

* cmake: Add openvino-encoder as separate object target

* whisper : minor style fixes

* minor : indentation fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-04 15:56:11 +03:00
176d7e4e7b readme : better wording (#1064) 2023-07-04 15:30:31 +03:00
70e6fcd78b readme : add tinydiarize instructions (#1058) 2023-07-04 09:51:22 +03:00
c8d0f5fe98 whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058)
* add HuggingFace mirror to download ggml model

* support tdrz via simple hack overriding solm tokens

* fix incorrect translate/transcribe token_ids that are not static const

* add apollo 13 sample for tdrz demo

* render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token

* extend whisper_segment with speaker_turn_next field and save in json output

* fix failing go build

* slipped in some python syntax whoops

* whisper : finalize tinydiarize support (add flag + fixes)

* whisper : tdrz support for word-level timestamps (respect max_len)

* java : try to fix tests after adding tdrz_enable flag

* main : remove TODO leftover

* java : fix params order list after adding "tdrz_enable"

* whisper : fix solm and add nosp token

* main : print tinydiarize help

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-07-04 09:45:00 +03:00
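A consumer-side sketch of the new field: run whisper_full() with a tinydiarize model and the tdrz flag enabled, then query the speaker-turn marker per segment (the getter name follows the speaker_turn_next field added here; check whisper.h for the exact symbol):

    #include <cstdio>
    #include "whisper.h"

    // Print each segment, marking where the model predicts a speaker change.
    static void print_segments(struct whisper_context * ctx) {
        const int n = whisper_full_n_segments(ctx);
        for (int i = 0; i < n; i++) {
            std::printf("%s", whisper_full_get_segment_text(ctx, i));
            if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
                std::printf(" [SPEAKER_TURN]");
            }
            std::printf("\n");
        }
    }
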
fdf58a6668 talk-llama : fix new rope interface 2023-07-03 19:24:01 +03:00
8ba42095c5 Revert "ggml : do not use _GNU_SOURCE gratuitously (#1027)"
This reverts commit 3f7a03ebe3.
2023-07-02 21:53:52 +03:00
d6509bf78d ggml : sync latest repo (mostly refactoring changes) 2023-07-02 21:46:09 +03:00
85ed71aaec talk-llama : fix build on macOS (#1062)
* talk-llama : use posix_madvise() instead of madvise() derived from BSD

sed -i 's,\<madvise\>,posix_&,g;s,\<MADV_,POSIX_&,g' examples/talk-llama/llama-util.h

* make : enable Darwin extensions for macOS builds

This is an attempt at fixing the macOS build error coming from the fact
that the RLIMIT_MEMLOCK define is not available there without Darwin
extensions.
2023-06-28 22:34:50 +03:00
49c9472fa0 extra : update 'quantize-all.sh' to quantize all downloaded models (#1054)
The script will now do what it says: quantize everything except the testing models in the 'models' directory.
2023-06-28 22:07:02 +03:00
72deb41eb2 whisper : split_on_word no longer trims (#1046) 2023-06-25 23:51:01 +03:00
3f7a03ebe3 ggml : do not use _GNU_SOURCE gratuitously (#1027)
* Do not use _GNU_SOURCE gratuitously.

What is needed to build whisper.cpp and examples is availability of
stuff defined in The Open Group Base Specifications Issue 6
(https://pubs.opengroup.org/onlinepubs/009695399/) known also as
Single Unix Specification v3 (SUSv3) or POSIX.1-2001 + XSI extensions.

There is no need to penalize musl libc which simply follows standards.

Not having feature test macros in source code gives greater flexibility
to those wanting to reuse it in a 3rd-party app, as they can build it with
a minimal FTM (_XOPEN_SOURCE=600) or other FTMs depending on their needs.

It builds without issues in Alpine (musl libc), Ubuntu (glibc), MSYS2.

* examples : include SDL headers before other headers

This is an attempt at fixing the macOS build error coming from SDL2
relying on the Darwin extensions memset_pattern4/8/16 from Apple's
string.h.
2023-06-25 16:34:30 +03:00
62642bb61c talk-llama : fix build after ggml sync (#1049)
sed -i 's,GGML_BACKEND_CUDA,GGML_BACKEND_GPU,g' examples/talk-llama/llama.cpp
2023-06-25 16:13:50 +03:00
f1c9df5806 metal : sync ggml-metal (ref #1047) 2023-06-25 15:40:39 +03:00
6c25fae1c4 opencl : sync latest ggml-opencl 2023-06-25 15:38:30 +03:00
44cb044e66 whisper : fix build with -Werror=undef (#1045) 2023-06-25 15:30:39 +03:00
6c68218e3c models : add ggml_to_pt script (#1042)
* adding ggml_to_pt

* typo sys too many args

* fixing swap errors dimensions

---------

Co-authored-by: simonMoisselin <simon.moisselin@gmail.com>
2023-06-25 15:29:54 +03:00
f11f33f1c0 models : cd statements are quoted to allow spaces in path (#1041) 2023-06-25 15:27:28 +03:00
8ac23c9f77 models : handle paths with spaces in download script (close #1038) 2023-06-25 15:23:23 +03:00
14baf2e7f3 main : add diarization support for all current output types (#1031)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-06-25 15:07:57 +03:00
bc2dcf85fe readme : add java alternative binding (#1029)
Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
2023-06-25 14:46:07 +03:00
1e45911f1a go : add support for whisper_full_lang_id() (#1010)
* Add support for whisper_full_lang_id() to go bindings

* Expose token.id so we can test beg, eot etc

---------

Co-authored-by: Jay Binks <jay.binks@overthewire.com.au>
2023-06-25 14:45:33 +03:00
67564201ec go : fix "cb" -> "callNewSegment" 2023-06-25 14:34:10 +03:00
5feb0dffba ggml : sync latest ggml lib 2023-06-25 14:30:44 +03:00
7dfc11843c go : improve progress reporting and callback handling (#1024)
- Rename `cb` to `callNewSegment` in the `Process` function
- Add `callProgress` as a new parameter to the `Process` function
- Introduce `ProgressCallback` type for reporting progress during processing
- Update `Whisper_full` function to include `progressCallback` parameter
- Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks

Signed-off-by: appleboy <appleboy.tw@gmail.com>
2023-06-25 14:07:55 +03:00
6a7f3b8db2 make : update cuBLAS build both x86 and aarch64 (#1015)
make cuBLAS compilation compatible with x86 as well as aarch64.
2023-06-25 13:59:48 +03:00
207a12f5bc make : fix for CUDA native not working as an option on Ubuntu (#1012) 2023-06-25 13:57:18 +03:00
26b70395ff main : exit gracefully when invalid params are passed
* Refactor whisper_params_parse to return false on failure

* Updated help flag behavior
2023-06-25 13:52:29 +03:00
598f607e28 main : gracefully exit when invalid params are passed (#1002)
* Refactor whisper_params_parse to return false on failure

* Updated help flag behavior
2023-06-25 13:51:59 +03:00
3ec7bfffe0 py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer files (#1001)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer

* typo fix
2023-06-25 13:50:14 +03:00
a7f822ef59 readme : corrected syntax for markdown link (#995) 2023-06-25 13:46:44 +03:00
57543c169e updated java README 2023-06-06 10:27:26 +10:00
5b9e59bc07 speak scripts for Windows 2023-06-01 22:45:00 +10:00
3f7436e8a0 updated README for java 2023-06-01 16:55:48 +10:00
ce6f747064 whisper.android : support decoding wav files with 2 channels (#972) 2023-05-31 10:13:14 +03:00
d7c936b44a Feature/java bindings2 (#944)
* Java needs to call `whisper_full_default_params_by_ref()`; returning the struct by value does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams
2023-05-29 09:38:58 +10:00
9b926844e3 models : fix README.md (#964)
Fixes typo on line 76 of models/README.md
2023-05-27 10:40:28 +03:00
5e2b3407ef examples : update elevenlabs scripts to use official python API (#837)
* Update elevenlabs example to use ufficial python API

* Update elevenlabs example to use official python API
2023-05-24 21:11:01 +03:00
4e16a8fb63 readme : highlight OpenBLAS support (#956)
* highlight openblas support

* Update README.md
2023-05-24 11:23:51 +03:00
77eab3fbfe talk-llama : sync latest llama.cpp (close #922, close #954) 2023-05-23 14:04:39 +03:00
041be06d58 cmake : build with any BLAS compatible library (#927)
* Build with any BLAS library

* ci: Removed explicit CUDA nvcc path
2023-05-20 21:23:45 +03:00
429b9785c0 ggml : update WASM SIMD 2023-05-20 20:00:06 +03:00
e410cfc3ce ggml : sync latest ggml repo
- new Q4 and Q8 quantization
- updated CUDA
2023-05-20 18:56:30 +03:00
bc89f285d8 bindings : add java bindings (#931)
* WIP - java bindings

* updated README

* failed attempt at JNI

* fullTranscribe() test passes

* tested on Ubuntu 20

* link to Java bindings
2023-05-20 18:25:02 +03:00
56a87ba45d whisper : fix hebrew language code (#935) 2023-05-20 18:17:54 +03:00
95b02d76b0 coreml : add support of large-v1 model (#926) 2023-05-15 18:36:06 +03:00
a5defbc1b9 release : v1.4.2 2023-05-14 19:06:45 +03:00
aaf0d41c7c ggml : add AVX dot products 2023-05-14 18:56:46 +03:00
0cb820e0f9 talk-llama : fix build + sync latest llama.cpp 2023-05-14 18:46:42 +03:00
16564f554f readme : improve Core ML model conversion guidance (#915) 2023-05-14 18:11:08 +03:00
fd01209d09 coreml : support quantized model files 2023-05-14 18:09:44 +03:00
e693074aa6 ggml : sync latest ggml
- New Q4 and Q5 formats
- Various improvements
2023-05-14 18:04:23 +03:00
d652cf12ec main : fix help for --no-timestamps arg (#908) 2023-05-14 17:54:57 +03:00
2b6a074305 extra : update ggml sync script 2023-05-14 10:01:52 +03:00
5300117471 whisper.objc : enable Core ML in example & fix segmentation fault (#910)
* coreml : update encoder header import path

* coreml : force objc_arc in whisper-encoder.mm

* whisper.objc : create coreml/ group link

* whisper.objc : add coreml model link

* whisper.objc : update readme

* coreml : use -fobjc-arc for coreml/whisper-encoder.mm

* ci: create dummy .mlmodelc to pass the iOS build

* whisper.objc : update readme

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-05-14 09:47:02 +03:00
70af52a316 coreml : fix seg fault, double free (#919, #917, #899) 2023-05-14 09:42:19 +03:00
1d17cd5bb3 coreml : fix memory leak (#899) 2023-05-09 18:38:12 +03:00
bf2449dfae cmake : fix define used for COREML_ALLOW_FALLBACK (#893) 2023-05-08 21:08:09 +03:00
4e4d00c67a talk-llama : only copy used KV cache in get / set state (#890)
---------

Co-authored-by: ejones <evan.q.jones@gmail.com>
2023-05-08 20:59:21 +03:00
9931d66400 readme : add instructions on converting to GGML + "--no-config" to wget (#874) 2023-05-08 20:58:36 +03:00
1a548c048e cmake : fix options disabling AVX and AVX2 flags (#885) 2023-05-08 20:45:53 +03:00
122 changed files with 57622 additions and 6771 deletions

@@ -1,31 +1,41 @@
name: CI
on: [push, pull_request]
env:
ubuntu_image: "ubuntu:22.04"
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install libsdl2-dev
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build
- name: Build ${{ matrix.arch }}
run: |
make
make stream
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential libsdl2-dev
make
make stream'
macOS-latest:
runs-on: macOS-latest
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
run: |
@@ -37,82 +47,104 @@ jobs:
make
make stream
freeBSD-latest:
runs-on: macos-12
steps:
- name: Clone
uses: actions/checkout@v3
- name: Build
uses: cross-platform-actions/action@v0.15.0
with:
operating_system: freebsd
version: '13.2'
run: |
sudo pkg update
sudo pkg install -y gmake sdl2
gmake
gmake stream
ubuntu-latest-gcc:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
- name: Build
run: |
make
ctest -L gh --output-on-failure
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ctest -L gh --output-on-failure'
ubuntu-latest-clang:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
- name: Build
run: |
make
ctest -L gh --output-on-failure
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
ubuntu-latest-gcc-sanitized:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
arch: [linux/amd64]
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
- name: Configure
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
- name: Build
run: |
make
ctest -L gh --output-on-failure
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
make
ctest -L gh --output-on-failure'
windows:
runs-on: windows-latest
@@ -125,14 +157,16 @@
include:
- arch: Win32
s2arc: x86
jnaPath: win32-x86
- arch: x64
s2arc: x64
jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@@ -159,6 +193,12 @@
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload dll
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
@@ -187,7 +227,7 @@
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@@ -253,7 +293,7 @@
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@@ -300,24 +340,16 @@
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Setup emsdk
uses: mymindstorm/setup-emsdk@v12
- name: Configure
run: echo "tmp"
- name: Verify
run: emcc -v
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
@@ -330,10 +362,12 @@
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Configure
run: cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
run: |
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
@@ -346,7 +380,7 @@
steps:
- name: Clone
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Install Java
uses: actions/setup-java@v3
@@ -361,3 +395,58 @@
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
java:
needs: [ 'windows' ]
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- name: Install Java
uses: actions/setup-java@v1
with:
java-version: 17
- name: Download Windows lib
uses: actions/download-artifact@v3
with:
name: win32-x86-64_whisper.dll
path: bindings/java/build/generated/resources/main/win32-x86-64
- name: Build
run: |
models\download-ggml-model.cmd tiny.en
cd bindings/java
chmod +x ./gradlew
./gradlew build
- name: Upload jar
uses: actions/upload-artifact@v3
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
- name: Publish package
if: ${{ github.ref == 'refs/heads/master' }}
uses: gradle/gradle-build-action@v2
with:
arguments: publish
build-root-directory: bindings/java
env:
MAVEN_USERNAME: ${{ secrets.JIRA_USER }}
MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }}
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
quantize:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v3
- name: Test quantize
run: |
./models/download-ggml-model.sh tiny.en
make quantize
./quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0

@@ -1,68 +0,0 @@
name: release-deb
on:
release:
types: [created]
jobs:
build:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Configure
run: |
set -x -e
VERSION=$(echo $GITHUB_REF | cut --delimiter=/ -f 3)
ID="whisper-cpp-small_${VERSION}_amd64"
echo "PKG_VERSION=$VERSION" >> $GITHUB_ENV
echo "PKG_ID=$ID" >> $GITHUB_ENV
- name: Install deps
run: |
sudo apt install -y --no-install-recommends intel-mkl
- name: Build
run: |
cmake -S . -B build-mkl \
-DCMAKE_BUILD_TYPE=Release\
-DBUILD_SHARED_LIBS=0\
-DWHISPER_BLAS=1\
-DWHISPER_BLAS_VENDOR=Intel10_64lp
cd build-mkl
make
cd ..
- name: Create package tree
env:
GITHUB_REPO: ${{ github.repository }}
run: |
export ROOT=$PKG_ID/opt/project/whisper.cpp
mkdir -p $ROOT/bin
mkdir -p $ROOT/share
mkdir -p $PKG_ID/DEBIAN
cp build-mkl/bin/main $ROOT/bin/whisper
cp -r contrib/debian/control $PKG_ID/DEBIAN/
echo "Version: $PKG_VERSION" >> $PKG_ID/DEBIAN/control
echo "Vcs-Git: $GITHUB_REPO" >> $PKG_ID/DEBIAN/control
echo "Vcs-Git-Commit: $GITHUB_SHA" >> $PKG_ID/DEBIAN/control
models/download-ggml-model.sh small
build-mkl/bin/quantize models/ggml-small.bin \
$ROOT/share/ggml-small-q5_1.bin q5_1
- name: Create deb package
run: |
mkdir artifacts
dpkg-deb --build --root-owner-group $PKG_ID
- name: Upload Release Asset
uses: xresloader/upload-to-github-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
release_id: ${{ github.event.release.id }}
file: ${{ env.PKG_ID }}.deb

.gitignore (7 changes)

@@ -5,19 +5,18 @@
.test/
.vs/
.vscode/
.idea/
.DS_Store
build/
build-em/
build-debug/
build-release/
build-rwdi/
build-static/
build-cublas/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
cmake-build-debug/
/main
/stream
@@ -26,6 +25,7 @@ cmake-build-debug/
/talk-llama
/bench
/quantize
/lsp
arm_neon.h
sync.sh
@@ -43,3 +43,6 @@ extra/bench-gg.txt
models/*.mlmodel
models/*.mlmodelc
models/*.mlpackage
bindings/java/.gradle/
bindings/java/.idea/
.idea/

CMakeLists.txt

@@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.4.1)
project(whisper.cpp VERSION 1.4.2)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -54,6 +54,8 @@ option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
option(WHISPER_NO_F16C "whisper: disable F16c" OFF)
option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
@@ -63,6 +65,7 @@ else()
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif()
@@ -123,7 +126,7 @@ if (APPLE)
endif()
if (WHISPER_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_COREML_ALLOW_FALLBACK)
endif()
endif()
endif()
@@ -134,25 +137,34 @@ if (WHISPER_OPENBLAS)
endif()
if (WHISPER_BLAS)
if (WHISPER_STATIC)
if (WIN32)
if(DEFINED ENV{OPENBLAS_PATH})
set(BLAS_LIBRARIES $ENV{OPENBLAS_PATH}/lib/libopenblas.dll.a)
message(STATUS "Libraries ${BLAS_LIBRARIES}")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories($ENV{OPENBLAS_PATH}/include)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else ()
message(WARNING "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
endif ()
else ()
set(BLA_STATIC 1)
else()
set(BLA_STATIC 0)
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
# set(BLA_PREFER_PKGCONFIG 1)
set(BLA_SIZEOF_INTEGER 8)
find_package(BLAS)
if(BLAS_FOUND)
message(STATUS "BLAS compatible library found")
message(STATUS "Libraries ${BLAS_LIBRARIES}")
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include/openblas /usr/local/include/openblas $ENV{BLAS_HOME}/include)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories(${BLAS_INCLUDE_DIRS})
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else()
message(WARNING "BLAS library was not found")
endif()
endif ()
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
set(BLA_SIZEOF_INTEGER 8)
find_package(BLAS)
if(BLAS_FOUND)
message(STATUS "BLAS compatible library found")
message(STATUS "Libraries ${BLAS_LIBRARIES}")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories(${BLAS_INCLUDE_DIRS})
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else()
message(WARNING "BLAS library was not found")
endif()
endif ()
if (WHISPER_CUBLAS)
@@ -162,7 +174,7 @@ if (WHISPER_CUBLAS)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
enable_language(CUDA)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
@@ -180,12 +192,43 @@ if (WHISPER_CUBLAS)
endif()
endif()
if (WHISPER_HIPBLAS)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
find_package(hip)
find_package(hipblas)
find_package(rocblas)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
if (WHISPER_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
endif()
endif()
if (WHISPER_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)
set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
@@ -195,6 +238,10 @@ if (WHISPER_CLBLAST)
endif()
endif()
if( WHISPER_OPENVINO )
find_package(OpenVINO REQUIRED COMPONENTS Runtime)
endif()
# compiler flags
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@@ -234,12 +281,25 @@ message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
message(STATUS "PowerPC detected")
else()
message(STATUS "x86 detected")
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /utf-8")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /utf-8")
if(NOT WHISPER_NO_AVX2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
endif()
endif()
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@@ -292,6 +352,24 @@ if (WHISPER_COREML)
)
endif()
if (WHISPER_OPENVINO)
set(TARGET whisper.openvino)
add_library(${TARGET} OBJECT
openvino/whisper-openvino-encoder.h
openvino/whisper-openvino-encoder.cpp
)
target_include_directories(${TARGET} PUBLIC
.
)
set_property(TARGET ${TARGET} PROPERTY POSITION_INDEPENDENT_CODE ON)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_OPENVINO)
target_link_libraries(${TARGET} PRIVATE openvino::runtime)
endif()
#
# whisper - this is the main library of the project
#
@@ -317,6 +395,10 @@ if (WHISPER_COREML)
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
endif()
if (WHISPER_OPENVINO)
target_link_libraries(${TARGET} PRIVATE whisper.openvino)
endif()
if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})

Makefile (194 changes)

@@ -12,6 +12,12 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
ifndef NVCC_VERSION
ifeq ($(call,$(shell which nvcc))$(.SHELLSTATUS),0)
NVCC_VERSION := $(shell nvcc --version | egrep -o "V[0-9]+.[0-9]+.[0-9]+" | cut -c2-)
endif
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
@@ -42,21 +48,16 @@ ifneq ($(wildcard /usr/include/musl/*),)
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
endif
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
# and on macOS its availability depends on enabling Darwin extensions
ifeq ($(UNAME_S),Darwin)
CFLAGS += -D_DARWIN_C_SOURCE
CXXFLAGS += -D_DARWIN_C_SOURCE
endif
# OS specific
# TODO: support Windows
ifeq ($(UNAME_S),Linux)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),Darwin)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),FreeBSD)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),Haiku)
ifeq ($(filter $(UNAME_S),Linux Darwin DragonFly FreeBSD NetBSD OpenBSD Haiku),$(UNAME_S))
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
@@ -64,66 +65,56 @@ endif
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
ifeq ($(UNAME_S),Darwin)
CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
ifneq (,$(findstring AVX2,$(AVX2_M)))
CFLAGS += -mavx2
endif
CPUINFO_CMD := sysctl machdep.cpu.features
else ifeq ($(UNAME_S),Linux)
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell grep "fma " /proc/cpuinfo)
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
CPUINFO_CMD := cat /proc/cpuinfo
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
CPUINFO_CMD := cat /proc/cpuinfo
else ifneq (,$(filter DragonFly FreeBSD,$(UNAME_S)))
CPUINFO_CMD := grep Features /var/run/dmesg.boot
else ifeq ($(UNAME_S),Haiku)
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell sysinfo -cpu | grep "FMA ")
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
else
CFLAGS += -mfma -mf16c -mavx -mavx2
CPUINFO_CMD := sysinfo -cpu
endif
ifdef CPUINFO_CMD
AVX_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX|AVX1.0')
ifneq (,$(AVX_M))
CFLAGS += -mavx
CXXFLAGS += -mavx
endif
AVX2_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX2')
ifneq (,$(AVX2_M))
CFLAGS += -mavx2
CXXFLAGS += -mavx2
endif
FMA_M := $(shell $(CPUINFO_CMD) | grep -iw 'FMA')
ifneq (,$(FMA_M))
CFLAGS += -mfma
CXXFLAGS += -mfma
endif
F16C_M := $(shell $(CPUINFO_CMD) | grep -iw 'F16C')
ifneq (,$(F16C_M))
CFLAGS += -mf16c
CXXFLAGS += -mf16c
endif
SSE3_M := $(shell $(CPUINFO_CMD) | grep -iwE 'PNI|SSE3')
ifneq (,$(SSE3_M))
CFLAGS += -msse3
CXXFLAGS += -msse3
endif
SSSE3_M := $(shell $(CPUINFO_CMD) | grep -iw 'SSSE3')
ifneq (,$(SSSE3_M))
CFLAGS += -mssse3
CXXFLAGS += -mssse3
endif
endif
endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifneq ($(filter ppc64%,$(UNAME_M)),)
@@ -155,29 +146,56 @@ endif
endif
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
LDFLAGS += -lopenblas
endif
ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
CUDA_ARCH_FLAG=native
else
CUDA_ARCH_FLAG=all
endif
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
ifdef WHISPER_HIPBLAS
ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
LDFLAGS += -lhipblas -lamdhip64 -lrocblas
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
WHISPER_OBJ += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif
ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast -lOpenCL
CXXFLAGS += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -framework OpenCL
else
LDFLAGS += -lOpenCL
endif
WHISPER_OBJ += ggml-opencl.o
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
$(CC) $(CFLAGS) -c $< -o $@
ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
$(CXX) $(CXXFLAGS) -c $< -o $@
endif
ifdef WHISPER_GPROF
@@ -240,7 +258,7 @@ ifndef WHISPER_COREML
WHISPER_OBJ += whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder.mm -o whisper-encoder.o
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
@@ -255,7 +273,7 @@ libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
#
# Examples
@@ -279,8 +297,11 @@ quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
@@ -301,12 +322,19 @@ samples:
@wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
@wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
@wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
@wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
@wget --quiet --show-progress -O samples/diffusion2023-07-03.flac https://archive.org/download/diffusion2023-07-03/diffusion2023-07-03.flac
@echo "Converting to 16-bit WAV ..."
@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
@ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
@ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
@rm samples/*.ogg
@ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
@rm samples/mm1.wav
@ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
@rm samples/a13.mp3
@ffmpeg -loglevel -0 -y -i samples/diffusion2023-07-03.flac -ar 16000 -ac 1 -c:a pcm_s16le samples/diffusion2023-07-03.wav
@rm samples/diffusion2023-07-03.flac
#
# Models
@@ -348,4 +376,4 @@ tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
.PHONY: tests
tests:
bash ./tests/run-tests.sh
bash ./tests/run-tests.sh $(word 2, $(MAKECMDGOALS))
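(With this change, a model name can be passed as a second make goal and is forwarded to the test script, e.g. `make tests base.en`, assuming `tests/run-tests.sh` accepts the model name as its first argument.)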

README.md

@@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Beta: [v1.4.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.1) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@@ -21,6 +21,8 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
Supported platforms:
@@ -28,6 +30,7 @@ Supported platforms:
- [x] Mac OS (Intel and Arm)
- [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android)
- [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
@@ -58,7 +61,7 @@ Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Quick start
@@ -71,6 +74,8 @@ Then, download one of the Whisper models converted in [ggml format](models). For
bash ./models/download-ggml-model.sh base.en
```
If you wish to convert the Whisper models to ggml format yourself, instructions are in [models/README.md](models/README.md).
Now build the [main](examples/main) example and transcribe an audio file like this:
```bash
@@ -111,6 +116,7 @@ options:
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
@@ -259,6 +265,12 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
pip install coremltools
```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.10 is recommended.
- [OPTIONAL] It is recommended to use a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html), for this step:
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use:
```bash
@@ -275,8 +287,8 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
WHISPER_COREML=1 make -j
# using CMake
cd build
cmake -DWHISPER_COREML=1 ..
cmake -B build -DWHISPER_COREML=1
cmake --build build -j --config Release
```
- Run the examples as usual. For example:
@@ -300,9 +312,88 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## OpenVINO support
On platforms that support [OpenVINO](https://github.com/openvinotoolkit/openvino), the Encoder inference can be executed
on OpenVINO-supported devices including x86 CPUs and Intel GPUs (integrated & discrete).
This can result in a significant speedup in encoder performance. Here are the instructions for generating the OpenVINO model and using it with `whisper.cpp`:
- First, set up a Python virtual environment and install the Python dependencies. Python 3.10 is recommended.
Windows:
```
cd models
python -m venv openvino_conv_env
openvino_conv_env\Scripts\activate
python -m pip install --upgrade pip
pip install -r openvino-conversion-requirements.txt
```
Linux and macOS:
```
cd models
python3 -m venv openvino_conv_env
source openvino_conv_env/bin/activate
python -m pip install --upgrade pip
pip install -r openvino-conversion-requirements.txt
```
- Generate an OpenVINO encoder model. For example, to generate a `base.en` model, use:
```
python convert-whisper-to-openvino.py --model base.en
```
This will produce ggml-base.en-encoder-openvino.xml/.bin IR model files. It's recommended to relocate these to the same folder as ggml models, as that
is the default location that the OpenVINO extension will search at runtime.
- Build `whisper.cpp` with OpenVINO support:
Download the OpenVINO package from the [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2023.0.0](https://github.com/openvinotoolkit/openvino/releases/tag/2023.0.0).
After downloading and extracting the package onto your development system, set up the required environment by sourcing the setupvars script. For example:
Linux:
```bash
source /path/to/l_openvino_toolkit_ubuntu22_2023.0.0.10926.b4452d56304_x86_64/setupvars.sh
```
Windows (cmd):
```
C:\Path\To\w_openvino_toolkit_windows_2023.0.0.10926.b4452d56304_x86_64\setupvars.bat
```
And then build the project using cmake:
```bash
cmake -B build -DWHISPER_OPENVINO=1
cmake --build build -j --config Release
```
- Run the examples as usual. For example:
```bash
./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
whisper_ctx_init_openvino_encoder: loading OpenVINO model from 'models/ggml-base.en-encoder-openvino.xml'
whisper_ctx_init_openvino_encoder: first run on a device may take a while ...
whisper_openvino_init: path_model = models/ggml-base.en-encoder-openvino.xml, device = GPU, cache_dir = models/ggml-base.en-encoder-openvino-cache
whisper_ctx_init_openvino_encoder: OpenVINO model loaded
system_info: n_threads = 4 / 8 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | OPENVINO = 1 |
...
```
The first run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This blob is cached for subsequent runs.
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support via cuBLAS
With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extend through cuBLAS.
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
Now build `whisper.cpp` with cuBLAS support:
@@ -314,7 +405,7 @@ WHISPER_CUBLAS=1 make -j
## OpenCL GPU support via CLBlast
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APU's or low end devices for up to ~2x speedup.
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APUs or low end devices for up to ~2x speedup.
First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast
@@ -327,16 +418,26 @@ make clean
WHISPER_CLBLAST=1 make -j
CMake:
cd whisper.cpp ; mkdir build ; cd build
cmake -DWHISPER_CLBLAST=ON ..
make clean
make -j
cp bin/* ../
cd whisper.cpp
cmake -B build -DWHISPER_CLBLAST=ON
cmake --build build -j --config Release
```
Run all the examples as usual.
## BLAS CPU support via OpenBLAS
Encoder processing can be accelerated on the CPU via OpenBLAS.
First, make sure you have installed `openblas`: https://www.openblas.net/
Now build `whisper.cpp` with OpenBLAS support:
```
make clean
WHISPER_OPENBLAS=1 make -j
```
## Limitations
- Inference only
@@ -471,7 +572,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.020 --> 00:00:11.000] country.
```
## Word-level timestamp
## Word-level timestamp (experimental)
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
@@ -512,6 +613,32 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.510 --> 00:00:11.000] .
```
## Speaker segmentation via tinydiarize (experimental)
More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
Sample usage:
```bash
# download a tinydiarize compatible model
./models/download-ggml-model.sh small.en-tdrz
# run as usual, adding the "-tdrz" command-line argument
./main -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
...
main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, tdrz = 1, timestamps = 1 ...
...
[00:00:00.000 --> 00:00:03.800] Okay Houston, we've had a problem here. [SPEAKER_TURN]
[00:00:03.800 --> 00:00:06.200] This is Houston. Say again please. [SPEAKER_TURN]
[00:00:06.200 --> 00:00:08.260] Uh Houston we've had a problem.
[00:00:08.260 --> 00:00:11.320] We've had a main beam up on a volt. [SPEAKER_TURN]
[00:00:11.320 --> 00:00:13.820] Roger main beam interval. [SPEAKER_TURN]
[00:00:13.820 --> 00:00:15.100] Uh uh [SPEAKER_TURN]
[00:00:15.100 --> 00:00:18.020] So okay stand, by thirteen we're looking at it. [SPEAKER_TURN]
[00:00:18.020 --> 00:00:25.740] Okay uh right now uh Houston the uh voltage is uh is looking good um.
[00:00:27.620 --> 00:00:29.940] And we had a a pretty large bank or so.
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
@@ -595,6 +722,8 @@ in [models](models).
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Java:
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)


@@ -32,7 +32,7 @@ mkdir:
modtidy:
@go mod tidy
clean:
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go clean


@@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil); err != nil {
if err := context.Process(samples, nil, nil); err != nil {
return err
}
@@ -71,7 +71,7 @@ The examples are placed in the `build` directory. Once built, you can download a
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings


@@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb); err != nil {
if err := context.Process(data, cb, nil); err != nil {
return err
}


@@ -19,6 +19,10 @@ func (p *Params) SetTranslate(v bool) {
p.translate = toBool(v)
}
func (p *Params) SetSplitOnWord(v bool) {
p.split_on_word = toBool(v)
}
func (p *Params) SetNoContext(v bool) {
p.no_context = toBool(v)
}


@@ -81,6 +81,10 @@ func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
func (context *context) SetSplitOnWord(v bool) {
context.params.SetSplitOnWord(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
@@ -93,7 +97,7 @@ func (context *context) SetOffset(v time.Duration) {
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
context.params.SetDuration(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
@@ -152,12 +156,16 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
func (context *context) Process(
data []float32,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
// If the callback is defined then we force on single_segment mode
if cb != nil {
if callNewSegment != nil {
context.params.SetSingleSegment(true)
}
@@ -165,24 +173,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if cb != nil {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if cb != nil {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
callNewSegment(toSegment(context.model.ctx, i))
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
return err
}


@@ -12,6 +12,10 @@ import (
// time. It is called during the Process function
type SegmentCallback func(Segment)
// ProgressCallback is the callback function for reporting progress during
// processing. It is called during the Process function
type ProgressCallback func(int)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
@@ -38,6 +42,7 @@ type Context interface {
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
@@ -47,7 +52,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback) error
Process([]float32, SegmentCallback, ProgressCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.


@@ -15,6 +15,7 @@ import (
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern void callProgress(void* user_data, int progress);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
@@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
}
}
// Progress callback
// Called periodically during inference to report the current progress
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callProgress(user_data, progress);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
@@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
params.progress_callback = whisper_progress_cb;
params.progress_callback_user_data = (void*)(ctx);
return params;
}
*/
@@ -258,13 +270,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
}
// Task tokens
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
func (ctx *Context) Whisper_token_translate() Token {
return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
}
// Task tokens
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
func (ctx *Context) Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
}
// Performance information
@@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
func (ctx *Context) Whisper_full(
params Params,
samples []float32,
encoderBeginCallback func() bool,
newSegmentCallback func(int),
progressCallback func(int),
) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
registerProgressCallback(ctx, progressCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
defer registerProgressCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
@@ -318,6 +338,18 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc
}
}
// Return the id of the autodetected language, returns -1 if not found
// Added to whisper.cpp in
// https://github.com/ggerganov/whisper.cpp/commit/a1c1583cc7cd8b75222857afc936f0638c5683d6
//
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_full_lang_id() int {
return int(C.whisper_full_lang_id((*C.struct_whisper_context)(ctx)))
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
@@ -370,6 +402,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbProgress = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
@@ -381,6 +414,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
}
}
func registerProgressCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbProgress, unsafe.Pointer(ctx))
} else {
cbProgress[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
@@ -396,6 +437,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
}
}
//export callProgress
func callProgress(user_data unsafe.Pointer, progress C.int) {
if fn, ok := cbProgress[user_data]; ok {
fn(int(progress))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
@@ -415,3 +463,7 @@ func (t TokenData) T0() int64 {
func (t TokenData) T1() int64 {
return int64(t.t1)
}
func (t TokenData) Id() Token {
return Token(t.id)
}
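At the lower level, the progress callback becomes a fifth argument to `Whisper_full`. A minimal sketch, assuming `ctx`, `params` and `samples` are already prepared via the low-level `bindings/go` package (the import path is an assumption):

```go
package example

import (
	"fmt"

	// assumed import path of the low-level Go bindings
	whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)

// runFull demonstrates the new five-argument form of Whisper_full.
func runFull(ctx *whisper.Context, params whisper.Params, samples []float32) error {
	return ctx.Whisper_full(params, samples,
		func() bool { return true },                     // encoder-begin: return false to abort
		func(n int) { fmt.Println("new segments:", n) }, // new-segment callback
		func(p int) { fmt.Println("progress:", p) },     // progress callback
	)
}
```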


@@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
err = ctx.Whisper_full(params, data, nil, nil, nil)
assert.NoError(err)
// Print out tokens

bindings/java/.idea/uiDesigner.xml (generated)

@@ -0,0 +1,124 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Palette2">
<group name="Swing">
<item class="com.intellij.uiDesigner.HSpacer" tooltip-text="Horizontal Spacer" icon="/com/intellij/uiDesigner/icons/hspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="1" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="com.intellij.uiDesigner.VSpacer" tooltip-text="Vertical Spacer" icon="/com/intellij/uiDesigner/icons/vspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="1" anchor="0" fill="2" />
</item>
<item class="javax.swing.JPanel" icon="/com/intellij/uiDesigner/icons/panel.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3" />
</item>
<item class="javax.swing.JScrollPane" icon="/com/intellij/uiDesigner/icons/scrollPane.svg" removable="false" auto-create-binding="false" can-attach-label="true">
<default-constraints vsize-policy="7" hsize-policy="7" anchor="0" fill="3" />
</item>
<item class="javax.swing.JButton" icon="/com/intellij/uiDesigner/icons/button.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="0" fill="1" />
<initial-values>
<property name="text" value="Button" />
</initial-values>
</item>
<item class="javax.swing.JRadioButton" icon="/com/intellij/uiDesigner/icons/radioButton.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="RadioButton" />
</initial-values>
</item>
<item class="javax.swing.JCheckBox" icon="/com/intellij/uiDesigner/icons/checkBox.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="CheckBox" />
</initial-values>
</item>
<item class="javax.swing.JLabel" icon="/com/intellij/uiDesigner/icons/label.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="8" fill="0" />
<initial-values>
<property name="text" value="Label" />
</initial-values>
</item>
<item class="javax.swing.JTextField" icon="/com/intellij/uiDesigner/icons/textField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JPasswordField" icon="/com/intellij/uiDesigner/icons/passwordField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JFormattedTextField" icon="/com/intellij/uiDesigner/icons/formattedTextField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JTextArea" icon="/com/intellij/uiDesigner/icons/textArea.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTextPane" icon="/com/intellij/uiDesigner/icons/textPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JEditorPane" icon="/com/intellij/uiDesigner/icons/editorPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JComboBox" icon="/com/intellij/uiDesigner/icons/comboBox.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="2" anchor="8" fill="1" />
</item>
<item class="javax.swing.JTable" icon="/com/intellij/uiDesigner/icons/table.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JList" icon="/com/intellij/uiDesigner/icons/list.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="2" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTree" icon="/com/intellij/uiDesigner/icons/tree.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTabbedPane" icon="/com/intellij/uiDesigner/icons/tabbedPane.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSplitPane" icon="/com/intellij/uiDesigner/icons/splitPane.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSpinner" icon="/com/intellij/uiDesigner/icons/spinner.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSlider" icon="/com/intellij/uiDesigner/icons/slider.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSeparator" icon="/com/intellij/uiDesigner/icons/separator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3" />
</item>
<item class="javax.swing.JProgressBar" icon="/com/intellij/uiDesigner/icons/progressbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="javax.swing.JToolBar" icon="/com/intellij/uiDesigner/icons/toolbar.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1">
<preferred-size width="-1" height="20" />
</default-constraints>
</item>
<item class="javax.swing.JToolBar$Separator" icon="/com/intellij/uiDesigner/icons/toolbarSeparator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="0" fill="1" />
</item>
<item class="javax.swing.JScrollBar" icon="/com/intellij/uiDesigner/icons/scrollbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="0" anchor="0" fill="2" />
</item>
</group>
</component>
</project>

bindings/java/README.md

@@ -0,0 +1,71 @@
# Java JNI bindings for Whisper
This package provides Java JNI bindings for whisper.cpp. They have been tested on:
* <strike>Darwin (OS X) 12.6 on x86_64</strike>
* Ubuntu on x86_64
* Windows on x86_64
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
JNA will attempt to load the `whispercpp` shared library from:
- jna.library.path
- jna.platform.library
- ~/Library/Frameworks
- /Library/Frameworks
- /System/Library/Frameworks
- classpath
```java
import io.github.ggerganov.whispercpp.WhisperCpp;
public class Example {
public static void main(String[] args) {
WhisperCpp whisper = new WhisperCpp();
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
long context = whisper.initContext("base.en");
try {
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// custom configuration if required
whisperParams.temperature_inc = 0f;
var samples = readAudio(); // divide each value by 32767.0f
whisper.fullTranscribe(whisperParams, samples);
int segmentCount = whisper.getTextSegmentCount(context);
for (int i = 0; i < segmentCount; i++) {
String text = whisper.getTextSegment(context, i);
System.out.println(text);
}
} finally {
whisper.freeContext(context);
}
}
}
```
## Building & Testing
In order to build, you need JDK 8 or higher installed. Run the tests with:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
./gradlew build
```
You need the `whisper` library on your [JNA library path](https://java-native-access.github.io/jna/4.2.1/com/sun/jna/NativeLibrary.html). On Windows, the DLL is included in the JAR and can be updated with:
```bash
copy /y ..\..\build\bin\Release\whisper.dll build\generated\resources\main\win32-x86-64\whisper.dll
```
## License
The license for the Java bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

bindings/java/build.gradle

@@ -0,0 +1,112 @@
plugins {
id 'java'
id 'java-library'
id 'maven-publish'
}
archivesBaseName = 'whispercpp'
group = 'io.github.ggerganov'
version = '1.4.0'
sourceCompatibility = 1.8
targetCompatibility = 1.8
sourceSets {
main {
resources {
srcDirs = ['src/main/resources', 'build/generated/resources/main']
}
}
test {
runtimeClasspath += files('build/generated/resources/main')
}
}
tasks.register('copyLibwhisperDynlib', Copy) {
from '../../build'
include 'libwhisper.dylib'
into 'build/generated/resources/main/darwin'
}
tasks.register('copyLibwhisperSo', Copy) {
from '../../build'
include 'libwhisper.so'
into 'build/generated/resources/main/linux-x86-64'
}
tasks.register('copyWhisperDll', Copy) {
from '../../build/Release'
include 'whisper.dll'
into 'build/generated/resources/main/win32-x86-64'
}
tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
}
test {
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
}
java {
withSourcesJar()
withJavadocJar()
}
jar {
exclude '**/whisper_java.exp', '**/whisper_java.lib'
}
javadoc {
options.addStringOption('Xdoclint:none', '-quiet')
}
tasks.withType(Test) {
useJUnitPlatform()
}
dependencies {
implementation "net.java.dev.jna:jna:5.13.0"
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"
testImplementation "org.assertj:assertj-core:3.24.2"
}
repositories {
mavenCentral()
}
publishing {
publications {
mavenJava(MavenPublication) {
artifactId = 'whispercpp'
from components.java
pom {
name = 'whispercpp'
description = "Java JNA bindings for OpenAI's Whisper model, implemented in C/C++"
url = 'https://github.com/ggerganov/whisper.cpp'
licenses {
license {
name = 'MIT licence'
url = 'https://raw.githubusercontent.com/ggerganov/whisper.cpp/master/LICENSE'
}
}
developers {
developer {
id = 'ggerganov'
name = 'Georgi Gerganov'
email = 'ggerganov@gmail.com'
}
developer {
id = 'nalbion'
name = 'Nicholas Albion'
email = 'nalbion@yahoo.com'
}
}
scm {
connection = 'scm:git:git://github.com/ggerganov/whisper.cpp.git'
url = 'https://github.com/ggerganov/whisper.cpp'
}
}
}
}
}


@@ -0,0 +1,6 @@
org.gradle.jvmargs=-Xms256m -Xmx1024m
system.include.dir=/usr/include
#system.local.include.dir=../../include
system.local.include.dir=./build/generated/sources/headers/java/main
jni.include.dir=/usr/lib/jvm/java-8-openjdk-amd64/include/
jni.lib.dir=/usr/lib/jvm/java-8-openjdk-amd64/lib/

Binary file not shown.


@@ -0,0 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

bindings/java/gradlew (vendored)

@@ -0,0 +1,244 @@
#!/bin/sh
#
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
#
# Gradle start up script for POSIX generated by Gradle.
#
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
warn () {
echo "$*"
} >&2
die () {
echo
echo "$*"
echo
exit 1
} >&2
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Collect all arguments for the java command, stacking in reverse order:
# * args from the command line
# * the main class name
# * -classpath
# * -D...appname settings
# * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

bindings/java/gradlew.bat (vendored)

@@ -0,0 +1,92 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega


@@ -0,0 +1 @@
rootProject.name = "whispercpp"


@@ -0,0 +1,39 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import java.util.List;
public class WhisperContext extends Structure {
int t_load_us = 0;
int t_start_us = 0;
/** weight type (FP32 / FP16 / QX) */
GgmlType wtype = GgmlType.GGML_TYPE_F16;
/** intermediate type (FP32 or FP16) */
GgmlType itype = GgmlType.GGML_TYPE_F16;
// WhisperModel model;
public PointerByReference model;
// whisper_vocab vocab;
// whisper_state * state = nullptr;
public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file() */
String path_model;
// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
//
// public static class ByValue extends WhisperContext implements Structure.ByValue {
// }
//
// @Override
// protected List<String> getFieldOrder() {
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
// }
}


@@ -0,0 +1,151 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
/**
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
*/
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
if (modelDirPath == null) {
modelDirPath = System.getProperty("user.home") + "/.cache";
}
return new File(modelDirPath, "whisper");
}
/**
* @param modelPath - absolute path, or just the name (e.g. "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
if (!modelPath.contains("/") && !modelPath.contains("\\")) {
if (!modelPath.endsWith(".bin")) {
modelPath = "ggml-" + modelPath.replace("-", ".") + ".bin";
}
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}
ctx = lib.whisper_init_from_file(modelPath);
if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must free it by either:
* - calling `whisper_free_params()`, or
* - calling `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - GREEDY
*/
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer;
// whisper_full_default_params_by_ref allocates memory which we need to delete, so create at most one pointer per strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
}
WhisperFullParams params = new WhisperFullParams(pointer);
params.read();
return params;
}
@Override
public void close() {
freeContext();
freeParams();
System.out.println("Whisper closed");
}
private void freeContext() {
if (ctx != null) {
lib.whisper_free(ctx);
}
}
private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
}
}
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread-safe for the same context.
* Uses the specified decoding strategy to obtain the text.
*/
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
int nSegments = lib.whisper_full_n_segments(ctx);
StringBuilder str = new StringBuilder();
for (int i = 0; i < nSegments; i++) {
String text = lib.whisper_full_get_segment_text(ctx, i);
System.out.println("Segment:" + text);
str.append(text);
}
return str.toString().trim();
}
// public int getTextSegmentCount(Pointer ctx) {
// return lib.whisper_full_n_segments(ctx);
// }
// public String getTextSegment(Pointer ctx, int index) {
// return lib.whisper_full_get_segment_text(ctx, index);
// }
public String getSystemInfo() {
return lib.whisper_print_system_info();
}
public int benchMemcpy(int nthread) {
return lib.whisper_bench_memcpy(nthread);
}
public int benchGgmlMulMat(int nthread) {
return lib.whisper_bench_ggml_mul_mat(nthread);
}
}


@@ -0,0 +1,376 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library {
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
String whisper_print_system_info();
/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer(Pointer buffer, int buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init(WhisperModelLoader loader);
/**
* Allocate (almost) all memory needed for the model by loading from a file without allocating the state.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_no_state(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer without allocating the state.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer_no_state(Pointer buffer, int buffer_size);
// Pointer whisper_init_from_buffer_no_state(Pointer buffer, long buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader without allocating the state.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_no_state(WhisperModelLoader loader);
/**
* Allocate memory for the Whisper state.
*
* @param ctx Whisper context
* @return Whisper state on success, null on failure
*/
Pointer whisper_init_state(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper context.
*
* @param ctx Whisper context
*/
void whisper_free(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper state.
*
* @param state Whisper state
*/
void whisper_free_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
*
* @param ctx - Pointer to a WhisperContext
* @return 0 on success
*/
int whisper_pcm_to_mel(Pointer ctx, final float[] samples, int n_samples, int n_threads);
/**
* @param ctx Pointer to a WhisperContext
* @param state Pointer to WhisperState
* @param n_samples
* @param n_threads
* @return 0 on success
*/
int whisper_pcm_to_mel_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/**
* This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
* Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
* n_mel must be 80
* @return 0 on success
*/
int whisper_set_mel(Pointer ctx, final float[] data, int n_len, int n_mel);
int whisper_set_mel_with_state(Pointer ctx, Pointer state, final float[] data, int n_len, int n_mel);
/**
* Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
* Offset can be used to specify the offset of the first frame in the spectrogram.
* @return 0 on success
*/
int whisper_encode(Pointer ctx, int offset, int n_threads);
int whisper_encode_with_state(Pointer ctx, Pointer state, int offset, int n_threads);
/**
* Run the Whisper decoder to obtain the logits and probabilities for the next token.
* Make sure to call whisper_encode() first.
* tokens + n_tokens is the provided context for the decoder.
* n_past is the number of tokens to use from previous decoder calls.
* Returns 0 on success
* TODO: add support for multiple decoders
*/
int whisper_decode(Pointer ctx, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* Run the Whisper decoder using the provided state.
*
* @param ctx Pointer to a WhisperContext
* @param state Pointer to a WhisperState
* @param tokens Pointer to int tokens
* @param n_tokens Number of tokens
* @param n_past Number of tokens to use from previous decoder calls
* @param n_threads Number of threads to use
* @return 0 on success
*/
int whisper_decode_with_state(Pointer ctx, Pointer state, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* Convert the provided text into tokens.
* The tokens pointer must be large enough to hold the resulting tokens.
* Returns the number of tokens on success, no more than n_max_tokens
* Returns -1 on failure
* TODO: not sure if correct
*/
int whisper_tokenize(Pointer ctx, String text, Pointer tokens, int n_max_tokens);
/** Largest language id (i.e. number of available languages - 1) */
int whisper_lang_max_id();
/**
* @return the id of the specified language, returns -1 if not found.
* Examples:
* "de" -> 2
* "german" -> 2
*/
int whisper_lang_id(String lang);
/** @return the short string of the specified language id (e.g. 2 -> "de"), returns null if not found */
String whisper_lang_str(int id);
/**
* Use mel data at offset_ms to try to auto-detect the spoken language.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
* Returns the top language id or negative on failure
* If not null, fills the lang_probs array with the probabilities of all languages
* The array must be whisper_lang_max_id() + 1 in size
*
* ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
*/
int whisper_lang_auto_detect(Pointer ctx, int offset_ms, int n_threads, float[] lang_probs);
int whisper_lang_auto_detect_with_state(Pointer ctx, Pointer state, int offset_ms, int n_threads, float[] lang_probs);
int whisper_n_len (Pointer ctx); // mel length
int whisper_n_len_from_state(Pointer state); // mel length
int whisper_n_vocab (Pointer ctx);
int whisper_n_text_ctx (Pointer ctx);
int whisper_n_audio_ctx (Pointer ctx);
int whisper_is_multilingual (Pointer ctx);
int whisper_model_n_vocab (Pointer ctx);
int whisper_model_n_audio_ctx (Pointer ctx);
int whisper_model_n_audio_state(Pointer ctx);
int whisper_model_n_audio_head (Pointer ctx);
int whisper_model_n_audio_layer(Pointer ctx);
int whisper_model_n_text_ctx (Pointer ctx);
int whisper_model_n_text_state (Pointer ctx);
int whisper_model_n_text_head (Pointer ctx);
int whisper_model_n_text_layer (Pointer ctx);
int whisper_model_n_mels (Pointer ctx);
int whisper_model_ftype (Pointer ctx);
int whisper_model_type (Pointer ctx);
/**
* Token logits obtained from the last call to whisper_decode().
* The logits for the last token are stored in the last row
* Rows: n_tokens
* Cols: n_vocab
*/
float[] whisper_get_logits (Pointer ctx);
float[] whisper_get_logits_from_state(Pointer state);
// Token Id -> String. Uses the vocabulary in the provided context
String whisper_token_to_str(Pointer ctx, int token);
String whisper_model_type_readable(Pointer ctx);
// Special tokens
int whisper_token_eot (Pointer ctx);
int whisper_token_sot (Pointer ctx);
int whisper_token_prev(Pointer ctx);
int whisper_token_solm(Pointer ctx);
int whisper_token_not (Pointer ctx);
int whisper_token_beg (Pointer ctx);
int whisper_token_lang(Pointer ctx, int lang_id);
// Task tokens
int whisper_token_translate (Pointer ctx);
int whisper_token_transcribe(Pointer ctx);
// Performance information from the default state.
void whisper_print_timings(Pointer ctx);
void whisper_reset_timings(Pointer ctx);
// Note: Even if `whisper_full_params` is stripped back to just 4 ints, JNA throws "Invalid memory access"
// when `whisper_full_default_params()` tries to return a struct.
// WhisperFullParams whisper_full_default_params(int strategy);
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must free it with either:
* - `whisper_free_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - WhisperSamplingStrategy.value
*/
Pointer whisper_full_default_params_by_ref(int strategy);
void whisper_free_params(Pointer params);
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread-safe for the same context.
* Uses the specified decoding strategy to obtain the text.
*/
int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
/**
* Number of generated text segments.
* A segment can be a few words, a sentence, or even a paragraph.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_n_segments (Pointer ctx);
/**
* @param state Pointer to WhisperState
*/
int whisper_full_n_segments_from_state(Pointer state);
/**
* Language id associated with the context's default state.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_lang_id(Pointer ctx);
/** Language id associated with the provided state */
int whisper_full_lang_id_from_state(Pointer state);
/**
* Convert RAW PCM audio to a log mel spectrogram, applying a Phase Vocoder to speed up the audio 2x.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
* @return 0 on success
*/
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/** Get the start time of the specified segment. */
long whisper_full_get_segment_t0(Pointer ctx, int i_segment);
/** Get the start time of the specified segment from the state. */
long whisper_full_get_segment_t0_from_state(Pointer state, int i_segment);
/** Get the end time of the specified segment. */
long whisper_full_get_segment_t1(Pointer ctx, int i_segment);
/** Get the end time of the specified segment from the state. */
long whisper_full_get_segment_t1_from_state(Pointer state, int i_segment);
/** Get the text of the specified segment. */
String whisper_full_get_segment_text(Pointer ctx, int i_segment);
/** Get the text of the specified segment from the state. */
String whisper_full_get_segment_text_from_state(Pointer state, int i_segment);
/** Get the number of tokens in the specified segment. */
int whisper_full_n_tokens(Pointer ctx, int i_segment);
/** Get the number of tokens in the specified segment from the state. */
int whisper_full_n_tokens_from_state(Pointer state, int i_segment);
/** Get the token text of the specified token in the specified segment. */
String whisper_full_get_token_text(Pointer ctx, int i_segment, int i_token);
/** Get the token text of the specified token in the specified segment from the state. */
String whisper_full_get_token_text_from_state(Pointer ctx, Pointer state, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment. */
int whisper_full_get_token_id(Pointer ctx, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment from the state. */
int whisper_full_get_token_id_from_state(Pointer state, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment. */
WhisperTokenData whisper_full_get_token_data(Pointer ctx, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment from the state. */
WhisperTokenData whisper_full_get_token_data_from_state(Pointer state, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment. */
float whisper_full_get_token_p(Pointer ctx, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment from the state. */
float whisper_full_get_token_p_from_state(Pointer state, int i_segment, int i_token);
/**
* Benchmark function for memcpy.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_memcpy(int nThreads);
/**
* Benchmark function for memcpy, returning the result as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_memcpy_str(int nThreads);
/**
* Benchmark function for ggml_mul_mat.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_ggml_mul_mat(int nThreads);
/**
* Benchmark function for ggml_mul_mat, returning the result as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_ggml_mul_mat_str(int nThreads);
}
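A minimal sketch of the init/free lifecycle through this raw interface (assumes the same package as the interface; the model path is hypothetical and not part of the bindings):
import com.sun.jna.Pointer;

class LifecycleSketch {
    public static void main(String[] args) {
        WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
        System.out.println(lib.whisper_print_system_info());

        // Hypothetical model path; point this at a real ggml model file.
        Pointer ctx = lib.whisper_init_from_file("models/ggml-tiny.en.bin");
        if (ctx == null) {
            throw new IllegalStateException("failed to initialize whisper context");
        }
        try {
            System.out.println("multilingual: " + lib.whisper_is_multilingual(ctx));
        } finally {
            lib.whisper_free(ctx); // frees the context and its default state
        }
    }
}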

View File

@ -0,0 +1,24 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback invoked before the encoder starts, if registered (non-null).
* If it returns false, the computation is aborted.
*/
public interface WhisperEncoderBeginCallback extends Callback {
/**
* Callback method before the encoder starts.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param user_data User data.
* @return True if the computation should proceed, false otherwise.
*/
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
}
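As a sketch, a cooperative-cancellation check can be implemented directly against this interface (hypothetical helper, same package assumed); it would be wired in via the WhisperFullParams setter shown later in this diff:
import com.sun.jna.Pointer;

class AbortOnInterrupt implements WhisperEncoderBeginCallback {
    @Override
    public boolean callback(Pointer ctx, Pointer state, Pointer user_data) {
        // Returning false here aborts the run before the encoder starts.
        return !Thread.currentThread().isInterrupted();
    }
}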

View File

@ -0,0 +1,25 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
/**
* Callback to filter logits.
* Can be used to modify the logits before sampling.
* If not null, called after applying temperature to logits.
*/
public interface WhisperLogitsFilterCallback extends Callback {
/**
* Callback method to filter logits.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param tokens The array of whisper_token_data.
* @param n_tokens The number of tokens.
* @param logits The array of logits.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
}
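A hypothetical filter that bans a single token id by pushing its logit to negative infinity; this sketch assumes JNA copies the modified float[] back to native memory when the callback returns:
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;

class SuppressTokenFilter implements WhisperLogitsFilterCallback {
    private final int bannedToken; // hypothetical token id to suppress

    SuppressTokenFilter(int bannedToken) {
        this.bannedToken = bannedToken;
    }

    @Override
    public void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens,
                         int n_tokens, float[] logits, Pointer user_data) {
        if (bannedToken >= 0 && bannedToken < logits.length) {
            logits[bannedToken] = Float.NEGATIVE_INFINITY; // never sampled
        }
    }
}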

View File

@ -0,0 +1,24 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for the text segment.
* Called on every newly generated text segment.
* Use the whisper_full_...() functions to obtain the text segments.
*/
public interface WhisperNewSegmentCallback extends Callback {
/**
* Callback method for the text segment.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param n_new The number of newly generated text segments.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
}
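A sketch that streams each newly decoded segment to stdout, using the segment getters declared in WhisperCppJnaLibrary above (same package assumed):
import com.sun.jna.Pointer;

class PrintNewSegments implements WhisperNewSegmentCallback {
    @Override
    public void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data) {
        WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
        int nSegments = lib.whisper_full_n_segments(ctx);
        // Only the last n_new segments are new since the previous invocation.
        for (int i = nSegments - n_new; i < nSegments; i++) {
            System.out.println(lib.whisper_full_get_segment_text(ctx, i));
        }
    }
}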

View File

@ -0,0 +1,22 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for progress updates.
*/
public interface WhisperProgressCallback extends Callback {
/**
* Callback method for progress updates.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param progress The progress value.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
}
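For instance, a minimal implementation that logs the reported value (whisper.cpp reports progress as a percentage; the test later in this diff wires an equivalent lambda through WhisperFullParams.setProgressCallback):
import com.sun.jna.Pointer;

class StderrProgress implements WhisperProgressCallback {
    @Override
    public void callback(Pointer ctx, Pointer state, int progress, Pointer user_data) {
        System.err.println("whisper progress: " + progress + "%");
    }
}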

View File

@ -0,0 +1,4 @@
package io.github.ggerganov.whispercpp.ggml;
public class GgmlTensor {
}

View File

@ -0,0 +1,18 @@
package io.github.ggerganov.whispercpp.ggml;
public enum GgmlType {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
REMOVED_GGML_TYPE_Q4_2, // support has been removed
REMOVED_GGML_TYPE_Q4_3, // support has been removed
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q8_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
GGML_TYPE_COUNT,
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.model;
public enum EModel {
MODEL_UNKNOWN,
MODEL_TINY,
MODEL_BASE,
MODEL_SMALL,
MODEL_MEDIUM,
MODEL_LARGE,
}

View File

@ -0,0 +1,49 @@
package io.github.ggerganov.whispercpp;
import io.github.ggerganov.whispercpp.ggml.GgmlTensor;
import io.github.ggerganov.whispercpp.model.EModel;
public class WhisperModel {
// EModel type = EModel.MODEL_UNKNOWN;
//
// WhisperHParams hparams;
// WhisperFilters filters;
//
// // encoder.positional_embedding
// GgmlTensor e_pe;
//
// // encoder.conv1
// GgmlTensor e_conv_1_w;
// GgmlTensor e_conv_1_b;
//
// // encoder.conv2
// GgmlTensor e_conv_2_w;
// GgmlTensor e_conv_2_b;
//
// // encoder.ln_post
// GgmlTensor e_ln_w;
// GgmlTensor e_ln_b;
//
// // decoder.positional_embedding
// GgmlTensor d_pe;
//
// // decoder.token_embedding
// GgmlTensor d_te;
//
// // decoder.ln
// GgmlTensor d_ln_w;
// GgmlTensor d_ln_b;
//
// std::vector<whisper_layer_encoder> layers_encoder;
// std::vector<whisper_layer_decoder> layers_decoder;
//
// // context
// struct ggml_context * ctx;
//
// // the model memory buffer is read-only and can be shared between processors
// std::vector<uint8_t> * buf;
//
// // tensors
// int n_loaded;
// Map<String, GgmlTensor> tensors;
}

View File

@ -0,0 +1,62 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import com.sun.jna.Structure;
public class WhisperModelLoader extends Structure {
public Pointer context;
public ReadFunction read;
public EOFFunction eof;
public CloseFunction close;
public static class ReadFunction implements Callback {
public Pointer invoke(Pointer ctx, Pointer output, int readSize) {
// TODO
return ctx;
}
}
public static class EOFFunction implements Callback {
public boolean invoke(Pointer ctx) {
// TODO
return false;
}
}
public static class CloseFunction implements Callback {
public void invoke(Pointer ctx) {
// TODO
}
}
// public WhisperModelLoader(Pointer p) {
// super(p);
// read = new ReadFunction();
// eof = new EOFFunction();
// close = new CloseFunction();
// read.setCallback(this);
// eof.setCallback(this);
// close.setCallback(this);
// read.write();
// eof.write();
// close.write();
// }
public WhisperModelLoader() {
super();
}
public interface ReadCallback extends Callback {
Pointer invoke(Pointer ctx, Pointer output, int readSize);
}
public interface EOFCallback extends Callback {
boolean invoke(Pointer ctx);
}
public interface CloseCallback extends Callback {
void invoke(Pointer ctx);
}
}
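A hypothetical in-memory loader built on these callback types. Note the Pointer return of the read callback mirrors the declaration above: the native side expects a size_t byte count, which this sketch smuggles through a pointer-sized value — that mapping is an assumption, not confirmed behavior. Such a loader would then be passed to whisper_init(...) or whisper_init_no_state(...).
import com.sun.jna.Pointer;

class InMemoryModelLoader extends WhisperModelLoader {
    private final byte[] data;
    private int pos = 0;

    InMemoryModelLoader(byte[] modelBytes) {
        this.data = modelBytes;
        this.read = new ReadFunction() {
            @Override
            public Pointer invoke(Pointer ctx, Pointer output, int readSize) {
                int n = Math.min(readSize, data.length - pos);
                if (n > 0) {
                    output.write(0, data, pos, n); // copy into the native buffer
                    pos += n;
                }
                return new Pointer(n); // size_t-style byte count (see note above)
            }
        };
        this.eof = new EOFFunction() {
            @Override
            public boolean invoke(Pointer ctx) {
                return pos >= data.length;
            }
        };
        this.close = new CloseFunction() {
            @Override
            public void invoke(Pointer ctx) {
                pos = data.length; // nothing to release for an in-memory buffer
            }
        };
    }
}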

View File

@ -0,0 +1,4 @@
package io.github.ggerganov.whispercpp.model;
public class WhisperState {
}

View File

@ -0,0 +1,50 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
/**
* Structure representing token data.
*/
public class WhisperTokenData extends Structure {
/** Token ID. */
public int id;
/** Forced timestamp token ID. */
public int tid;
/** Probability of the token. */
public float p;
/** Log probability of the token. */
public float plog;
/** Probability of the timestamp token. */
public float pt;
/** Sum of probabilities of all timestamp tokens. */
public float ptsum;
/**
* Start time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t0;
/**
* End time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t1;
/** Voice length of the token. */
public float vlen;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("id", "tid", "p", "plog", "pt", "ptsum", "t0", "t1", "vlen");
}
}
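As a usage sketch (hypothetical helper; assumes ctx was populated by an earlier whisper_full() call through the interface above):
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperCppJnaLibrary;

class TokenDump {
    // Print token text, probability, and log probability for one segment.
    static void dumpSegment(Pointer ctx, int iSegment) {
        WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
        int n = lib.whisper_full_n_tokens(ctx, iSegment);
        for (int j = 0; j < n; j++) {
            WhisperTokenData td = lib.whisper_full_get_token_data(ctx, iSegment, j);
            System.out.printf("%-16s p=%.3f plog=%.3f%n",
                    lib.whisper_full_get_token_text(ctx, iSegment, j), td.p, td.plog);
        }
    }
}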

View File

@ -0,0 +1,19 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
public class BeamSearchParams extends Structure {
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
public int beam_size;
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
public float patience;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("beam_size", "patience");
}
}

View File

@ -0,0 +1,30 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.IntegerType;
import java.util.function.BooleanSupplier;
public class CBool extends IntegerType implements BooleanSupplier {
public static final int SIZE = 1;
public static final CBool FALSE = new CBool(0);
public static final CBool TRUE = new CBool(1);
public CBool() {
this(0);
}
public CBool(long value) {
super(SIZE, value, true);
}
@Override
public boolean getAsBoolean() {
return intValue() == 1;
}
@Override
public String toString() {
return intValue() == 1 ? "true" : "false";
}
}
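The 1-byte size matters: native bool in the whisper_full_params struct typically occupies a single byte, while JNA's default mapping for Java boolean is a 4-byte int, which would shift every following struct field; this IntegerType keeps the layout aligned. A trivial usage check:
class CBoolDemo {
    public static void main(String[] args) {
        System.out.println(new CBool(1).getAsBoolean()); // true
        System.out.println(CBool.FALSE);                 // prints "false"
        System.out.println(CBool.SIZE);                  // 1 byte on the wire
    }
}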

View File

@ -0,0 +1,16 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Collections;
import java.util.List;
public class GreedyParams extends Structure {
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
public int best_of;
@Override
protected List<String> getFieldOrder() {
return Collections.singletonList("best_of");
}
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.params;
import java.util.List;
public class WhisperFilters {
int n_mel;
int n_fft;
List<Float> data;
}

View File

@ -0,0 +1,321 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import java.util.Arrays;
import java.util.List;
/**
* Parameters for the whisper_full() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_full_default_params()
*/
public class WhisperFullParams extends Structure {
public WhisperFullParams(Pointer p) {
super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
}
/** Sampling strategy for whisper_full() function. */
public int strategy;
/** Number of threads. (default = 4) */
public int n_threads;
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
public int n_max_text_ctx;
/** Start offset in milliseconds. (default = 0) */
public int offset_ms;
/** Audio duration to process in milliseconds. (default = 0) */
public int duration_ms;
/** Translate flag. (default = false) */
public CBool translate;
/** The complement of translateMode() */
public void transcribeMode() {
translate = CBool.FALSE;
}
/** The complement of transcribeMode() */
public void translateMode() {
translate = CBool.TRUE;
}
/** If true, do not use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public CBool no_context;
/** Enable or disable using past transcription (if any) as an initial prompt for the decoder. */
public void enableContext(boolean enable) {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}
/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;
/** Flag to force single segment output (useful for streaming). (default = false) */
public void singleSegment(boolean single) {
single_segment = single ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */
public CBool print_special;
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */
public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print progress information. (default = true) */
public CBool print_progress;
/** Flag to print progress information. (default = true) */
public void printProgress(boolean enable) {
print_progress = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public CBool print_realtime;
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public void printRealtime(boolean enable) {
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public CBool print_timestamps;
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public void printTimestamps(boolean enable) {
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public CBool token_timestamps;
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public void tokenTimestamps(boolean enable) {
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
public float thold_pt;
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
public float thold_ptsum;
/** Maximum segment length in characters. (default = 0) */
public int max_len;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public CBool split_on_word;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public void splitOnWord(boolean enable) {
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
}
/** Maximum tokens per segment (0, default = no limit) */
public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */
public int audio_ctx;
/** Enable tinydiarize (default = false) */
public CBool tdrz_enable;
/** Enable tinydiarize (default = false) */
public void tdrzEnable(boolean enable) {
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
}
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
public void setPromptTokens(int[] tokens) {
Memory mem = new Memory(tokens.length * 4L);
mem.write(0, tokens, 0, tokens.length);
prompt_tokens = mem;
}
/** Number of prompt tokens. */
public int prompt_n_tokens;
/** Language to use for decoding.
* For auto-detection, set to `null`, `""`, or `"auto"`. */
public String language;
/** Flag to indicate whether to detect language automatically. */
public CBool detect_language;
/** Flag to indicate whether to detect language automatically. */
public void detectLanguage(boolean enable) {
detect_language = enable ? CBool.TRUE : CBool.FALSE;
}
// Common decoding parameters.
/** Flag to suppress blank tokens. */
public CBool suppress_blank;
public void suppressBlanks(boolean enable) {
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to suppress non-speech tokens. */
public CBool suppress_non_speech_tokens;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
public float temperature;
/** Maximum initial timestamp. */
public float max_initial_ts;
/** Length penalty. */
public float length_penalty;
// Fallback parameters.
/** Temperature increment. */
public float temperature_inc;
/** Entropy threshold (similar to OpenAI's "compression_ratio_threshold"). */
public float entropy_thold;
/** Log probability threshold. */
public float logprob_thold;
/** No speech threshold. */
public float no_speech_thold;
/** Greedy decoding parameters. */
public GreedyParams greedy;
/**
* Beam search decoding parameters.
*/
public BeamSearchParams beam_search;
public void setBestOf(int bestOf) {
if (greedy == null) {
greedy = new GreedyParams();
}
greedy.best_of = bestOf;
}
public void setBeamSize(int beamSize) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
}
public void setBeamSizeAndPatience(int beamSize, float patience) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
beam_search.patience = patience;
}
/**
* Callback for every newly generated text segment.
* WhisperNewSegmentCallback
*/
public Pointer new_segment_callback;
/**
* User data for the new_segment_callback.
*/
public Pointer new_segment_callback_user_data;
/**
* Callback on each progress update.
* WhisperProgressCallback
*/
public Pointer progress_callback;
/**
* User data for the progress_callback.
*/
public Pointer progress_callback_user_data;
/**
* Callback each time before the encoder starts.
* WhisperEncoderBeginCallback
*/
public Pointer encoder_begin_callback;
/**
* User data for the encoder_begin_callback.
*/
public Pointer encoder_begin_callback_user_data;
/**
* Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback
*/
public Pointer logits_filter_callback;
/**
* User data for the logits_filter_callback.
*/
public Pointer logits_filter_callback_user_data;
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
new_segment_callback = CallbackReference.getFunctionPointer(callback);
}
public void setProgressCallback(WhisperProgressCallback callback) {
progress_callback = CallbackReference.getFunctionPointer(callback);
}
public void setEncoderBeginCallback(WhisperEncoderBeginCallback callback) {
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
}
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data");
}
}
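Putting it together, a hedged end-to-end sketch: native defaults are fetched by reference, mapped onto this structure, adjusted, and used for a full run. The model path is hypothetical and the audio is a silent placeholder buffer; replace both with real inputs.
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperCppJnaLibrary;

class FullRunSketch {
    public static void main(String[] args) {
        WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
        Pointer ctx = lib.whisper_init_from_file("models/ggml-tiny.en.bin"); // hypothetical path
        Pointer defaults = lib.whisper_full_default_params_by_ref(
                WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal());
        WhisperFullParams params = new WhisperFullParams(defaults);
        params.read();              // pull the native defaults into the Java fields
        params.transcribeMode();
        params.printProgress(false);
        params.setBestOf(5);

        float[] samples = new float[16000]; // 1 s of silence; use real 16 kHz mono PCM here
        if (lib.whisper_full(ctx, params, samples, samples.length) == 0) {
            for (int i = 0; i < lib.whisper_full_n_segments(ctx); i++) {
                System.out.println(lib.whisper_full_get_segment_text(ctx, i));
            }
        }
        lib.whisper_free_params(defaults); // see the whisper_full_default_params_by_ref() note
        lib.whisper_free(ctx);
    }
}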

View File

@ -0,0 +1,15 @@
package io.github.ggerganov.whispercpp.params;
public class WhisperHParams {
int n_vocab = 51864;
int n_audio_ctx = 1500;
int n_audio_state = 384;
int n_audio_head = 6;
int n_audio_layer = 4;
int n_text_ctx = 448;
int n_text_state = 384;
int n_text_head = 6;
int n_text_layer = 4;
int n_mels = 80;
int ftype = 1;
}

View File

@ -0,0 +1,10 @@
package io.github.ggerganov.whispercpp.params;
/** Available sampling strategies */
public enum WhisperSamplingStrategy {
/** similar to OpenAI's GreedyDecoder */
WHISPER_SAMPLING_GREEDY,
/** similar to OpenAI's BeamSearchDecoder */
WHISPER_SAMPLING_BEAM_SEARCH
}

View File

@ -0,0 +1,102 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ggerganov.whispercpp.params.CBool;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.FileNotFoundException;
class WhisperCppTest {
private static WhisperCpp whisper = new WhisperCpp();
private static boolean modelInitialised = false;
@BeforeAll
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
}
}
@Test
void testGetDefaultFullParams_BeamSearch() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(2, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@Test
void testGetDefaultFullParams_Greedy() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(2, params.greedy.best_of);
}
@Test
void testFullTranscribe() throws Exception {
if (!modelInitialised) {
System.out.println("Model not initialised, skipping test");
return;
}
// Given
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";
try {
audioInputStream.read(b);
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
floats[j] = intSample / 32767.0f;
}
// When
String result = whisper.fullTranscribe(params, floats);
// Then
System.err.println(result);
assertEquals("And so my fellow Americans ask not what your country can do for you " +
"ask what you can do for your country.",
result.replace(",", ""));
} finally {
audioInputStream.close();
}
}
}

View File

@ -0,0 +1,17 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
class WhisperJnaLibraryTest {
@Test
void testWhisperPrint_system_info() {
String systemInfo = WhisperCppJnaLibrary.instance.whisper_print_system_info();
// eg: "AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0
// | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | "
System.out.println("System info: " + systemInfo);
assertTrue(systemInfo.length() > 10);
}
}

View File

@ -1 +1 @@
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f)},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}var initializedJS=false;var pendingNotifiedProxyingQueues=[];function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};self.onunhandledrejection=e=>{throw e.reason??e};self.onmessage=e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=function(){postMessage({cmd:"callHandler",handler:handler,args:[...arguments]})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();pendingNotifiedProxyingQueues.forEach(queue=>{Module["executeNotifiedProxyingQueue"](queue)});pendingNotifiedProxyingQueues=[];initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processProxyingQueue"){if(initializedJS){Module["executeNotifiedProxyingQueue"](e.data.queue)}else{pendingNotifiedProxyingQueues.push(e.data.queue)}}else if(e.data.cmd){err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}};
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:f=>(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f),postMessage:msg=>parentPort.postMessage(msg),performance:global.performance||{now:Date.now}})}var initializedJS=false;function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var module=Module["wasmModule"];Module["wasmModule"]=null;var instance=new WebAssembly.Instance(module,info);return receiveInstance(instance)};self.onunhandledrejection=e=>{throw e.reason||e};function handleMessage(e){try{if(e.data.cmd==="load"){let messageQueue=[];self.onmessage=e=>messageQueue.push(e);self.startWorker=instance=>{Module=instance;postMessage({"cmd":"loaded"});for(let msg of messageQueue){handleMessage(msg)}self.onmessage=handleMessage};Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=(...args)=>{postMessage({cmd:"callHandler",handler:handler,args:args})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module)}else if(e.data.cmd==="run"){Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["__emscripten_thread_mailbox_await"](e.data.pthread_ptr);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){throw ex}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="checkMailbox"){if(initializedJS){Module["checkMailbox"]()}}else if(e.data.cmd){err(`worker.js received unknown command ${e.data.cmd}`);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}}self.onmessage=handleMessage;

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.4.1",
"version": "1.4.2",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

View File

@ -1,5 +0,0 @@
Package: whisper-small-cpp
Architecture: amd64
Maintainer: Alexey Kharlamov <alexey@kharlamov.biz>
Description: Whisper Speech to Text Converter
Depends: libc6 (>= 2.2.1), intel-mkl

View File

@ -1,5 +1,9 @@
#import "coreml/whisper-encoder.h"
#import "coreml/whisper-encoder-impl.h"
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder.h"
#import "whisper-encoder-impl.h"
#import <CoreML/CoreML.h>
@ -49,17 +53,11 @@ void whisper_coreml_encode(
error: nil
];
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
@autoreleasepool {
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
MLMultiArray * outMA = outCoreML.output;
//NSArray<NSNumber *> * shape = outMA.shape;
//NSArray<NSNumber *> * strides = outMA.strides;
//printf("shape: %ld %ld %ld %ld\n", [shape[0] longValue], [shape[1] longValue], [shape[2] longValue], [shape[3] longValue]);
//printf("strides: %ld %ld %ld %ld\n", [strides[0] longValue], [strides[1] longValue], [strides[2] longValue], [strides[3] longValue]);
memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
memcpy(out, outCoreML.output.dataPointer, outCoreML.output.count * sizeof(float));
}
}
#if __cplusplus

View File

@ -23,6 +23,7 @@ add_library(${TARGET} STATIC
common.cpp
common-ggml.h
common-ggml.cpp
grammar-parser.cpp
)
include(DefaultTargetOptions)
@ -69,4 +70,5 @@ else()
add_subdirectory(quantize)
add_subdirectory(talk)
add_subdirectory(talk-llama)
add_subdirectory(lsp)
endif()

View File

@ -9,6 +9,7 @@
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "grammar-parser.h"
#include <sstream>
#include <cassert>
@ -21,6 +22,11 @@
#include <vector>
#include <map>
bool file_exists(const std::string & fname) {
std::ifstream f(fname.c_str());
return f.good();
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -30,8 +36,12 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float grammar_penalty = 100.0f;
grammar_parser::parse_state grammar_parsed;
bool speed_up = false;
bool translate = false;
@ -44,6 +54,8 @@ struct whisper_params {
std::string fname_out;
std::string commands;
std::string prompt;
std::string context;
std::string grammar;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -73,6 +85,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -106,16 +121,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string & grammar_rule,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
logprob_min = 0.0f;
logprob_sum = 0.0f;
n_tokens = 0;
t_ms = 0;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.print_progress = false;
wparams.print_special = params.print_special;
@ -123,19 +152,37 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 5;
wparams.beam_search.beam_size = 5;
wparams.initial_prompt = params.context.data();
const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
wparams.grammar_penalty = params.grammar_penalty;
}
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
@ -144,19 +191,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const int n = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
if(token.plog > 0.0f) exit(0);
logprob_min = std::min(logprob_min, token.plog);
logprob_sum += token.plog;
++n_tokens;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
@ -247,7 +292,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
fprintf(stderr, " ]\n");
}
std::string k_prompt = "select one from the available words: ";
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
@ -415,7 +460,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
bool is_running = true;
bool ask_prompt = true;
float prob = 0.0f;
float logprob_min = 0.0f;
float logprob_sum = 0.0f;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
@ -453,7 +500,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
const auto words = get_words(txt);
@ -489,18 +536,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;
int n_tokens0 = 0;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands.";
std::string k_prompt = "Ok Whisper, start listening for commands.";
if (!params.prompt.empty()) {
k_prompt = params.prompt;
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
@ -533,9 +589,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float p = 100.0f * std::exp(logprob_min0);
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
const float sim = similarity(txt, k_prompt);
@ -556,19 +614,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
// prepend 3 seconds of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
prob = 100.0f*(prob - prob0);
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
const float p = 100.0f * std::exp(logprob_min);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
if (n >= txt.size()) {
break;
}
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
@ -581,9 +650,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
if (best_len == 0) {
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
} else {
// cut the prompt from the decoded text
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
}
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
@ -648,12 +724,36 @@ int main(int argc, char ** argv) {
int ret_val = 0;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
if (!params.grammar.empty()) {
auto & grammar = params.grammar_parsed;
if (file_exists(params.grammar.c_str())) {
// read grammar from file
std::ifstream ifs(params.grammar.c_str());
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
grammar = grammar_parser::parse(txt.c_str());
} else {
// read grammar from string
grammar = grammar_parser::parse(params.grammar.c_str());
}
// will be empty (default) if there are parse errors
if (grammar.rules.empty()) {
ret_val = 1;
} else {
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, grammar);
fprintf(stderr, "\n");
}
}
if (ret_val == 0) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
}
}
audio.pause();

View File

@ -6,7 +6,6 @@
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
{"q4_2", GGML_FTYPE_MOSTLY_Q4_2},
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
@ -46,7 +45,6 @@ bool ggml_common_quantize_0(
switch (ftype) {
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q4_2: qtype = GGML_TYPE_Q4_2; break;
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
@ -54,6 +52,11 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_Q2_K:
case GGML_FTYPE_MOSTLY_Q3_K:
case GGML_FTYPE_MOSTLY_Q4_K:
case GGML_FTYPE_MOSTLY_Q5_K:
case GGML_FTYPE_MOSTLY_Q6_K:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -171,10 +174,6 @@ bool ggml_common_quantize_0(
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_2:
{
cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
@ -193,6 +192,12 @@ bool ggml_common_quantize_0(
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));

View File

@ -1,3 +1,5 @@
#define _USE_MATH_DEFINES // for M_PI
#include "common.h"
// third-party utilities
@ -6,39 +8,79 @@
#include "dr_wav.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <regex>
#include <locale>
#include <codecvt>
#include <sstream>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Function to check if the next argument exists
std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
if (i + 1 < argc && argv[i + 1][0] != '-') {
return argv[++i];
} else {
fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
gpt_print_usage(argc, argv, params);
exit(0);
}
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i];
params.prompt = get_next_arg(i, argc, argv, arg, params);
} else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]);
params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--top_k") {
params.top_k = std::stoi(argv[++i]);
params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--repeat-last-n") {
params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--repeat-penalty") {
params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
params.model = get_next_arg(i, argc, argv, arg, params);
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "-ip" || arg == "--interactive-port") {
params.interactive = true;
params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else {
} else if (arg == "-f" || arg == "--file") {
get_next_arg(i, argc, argv, arg, params);
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (!params.prompt.empty() && params.prompt.back() == '\n') {
params.prompt.pop_back();
}
} else if (arg == "-tt" || arg == "--token_test") {
params.token_test = get_next_arg(i, argc, argv, arg, params);
}
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
exit(0);
@@ -55,12 +97,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " load prompt from a file\n");
fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
fprintf(stderr, " test tokenization\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
@@ -101,6 +150,10 @@ std::string replace(const std::string & s, const std::string & from, const std::
return result;
}
void gpt_vocab::add_special_token(const std::string & token) {
special_tokens.push_back(token);
}
std::map<std::string, int32_t> json_parse(const std::string & fname) {
std::map<std::string, int32_t> result;
@@ -192,54 +245,82 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result;
}
std::string convert_to_utf8(const std::wstring & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(input);
}
std::wstring convert_to_wstring(const std::string & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.from_bytes(input);
}
void gpt_split_words(std::string str, std::vector<std::string>& words) {
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
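// Illustrative example (not part of the upstream file): with the GPT-2 pattern above,
// "it's 2 cats!" splits into ["it", "'s", " 2", " cats", "!"]; leading spaces stay
// attached to the following word, matching how the BPE vocab stores tokens.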
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
std::string special_tokens_subpattern;
for (const auto & token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
}
str = m.suffix();
std::regex re(special_tokens_subpattern);
std::smatch m;
// Split the text by special tokens.
while (std::regex_search(str, m, re)) {
// Split the substrings in-between special tokens into words.
gpt_split_words(m.prefix(), words);
// Add matched special tokens as words.
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
// Remaining text without special tokens will be handled below.
}
gpt_split_words(str, words);
}
// find the longest tokens that form the words:
// find the longest token that forms each word in words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
for (int i = 0; i < (int) word.size(); ){
for (int j = word.size() - 1; j >= i; j--){
auto cand = word.substr(i, j-i+1);
auto it = vocab.token_to_id.find(cand);
if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
tokens.push_back(it->second);
i = j;
i = j + 1;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
else if (j == i){ // word.substr(i, 1) has no match in the vocab
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
i++;
}
++i;
}
}
}
@@ -247,6 +328,70 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
return tokens;
}
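// Illustrative example (not part of the upstream file): with a vocab containing
// {"ab", "a", "b"}, the greedy longest-match loop above tokenizes "abab" as
// ["ab", "ab"] rather than ["a", "b", "a", "b"], since each inner pass tries the
// longest remaining substring first and stops at the first vocab hit.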
std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
std::vector<gpt_vocab::id> output;
std::stringstream ss(input);
std::string token;
while (std::getline(ss, token, delimiter)) {
output.push_back(std::stoi(token));
}
return output;
}
std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
if (fpath_test.empty()){
fprintf(stderr, "%s : No test file found.\n", __func__);
return std::map<std::string, std::vector<gpt_vocab::id>>();
}
std::map<std::string, std::vector<gpt_vocab::id>> tests;
auto fin = std::ifstream(fpath_test, std::ios_base::in);
const char * delimiter = " => ";
const char del_tok = ',';
std::string line;
while (std::getline(fin, line)) {
size_t delimiterPos = line.find(delimiter);
if (delimiterPos != std::string::npos) {
std::string text = line.substr(0, delimiterPos);
std::string s_tokens = line.substr(delimiterPos + std::strlen(delimiter));
tests[text] = parse_tokens_from_string(s_tokens, del_tok);
}
}
return tests;
}
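// Illustrative test-file line (the token ids are made up for the example):
//
//   Hello world => 15496,995
//
// extract_tests_from_file() stores "Hello world" as the key and
// parse_tokens_from_string("15496,995", ',') == {15496, 995} as the expected tokens.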
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);
size_t n_fails = 0;
for (const auto & test : tests) {
std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);
if (tokens != test.second){
n_fails++;
// print out failure cases
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
fprintf(stderr, "%s : tokens in hf: ", __func__);
for (const auto & t : test.second) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : tokens in ggml: ", __func__);
for (const auto & t : tokens) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
@@ -346,6 +491,122 @@ gpt_vocab::id gpt_sample_top_k_top_p(
return logits_id[idx].second;
}
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
const auto * plogits = logits;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
gpt_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then the repetition penalty has to be multiplied to reduce the token's probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
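// Illustrative numbers (not from the file): with temp = 1.0 and repeat_penalty = 1.3,
// a token present in last_n_tokens with logit 2.0 is rescored to 2.0/1.3 ~= 1.54,
// while one with logit -2.0 is rescored to -2.0*1.3 = -2.6; both lose probability.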
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
// printf("\n");
// for (int i = 0; i < (int) probs.size(); i++) {
// for (int i = 0; i < 10; i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
// }
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
@@ -503,3 +764,46 @@ float similarity(const std::string & s0, const std::string & s1) {
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
bool sam_params_parse(int argc, char ** argv, sam_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--inp") {
params.fname_inp = argv[++i];
} else if (arg == "-o" || arg == "--out") {
params.fname_out = argv[++i];
} else if (arg == "-h" || arg == "--help") {
sam_print_usage(argc, argv, params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
sam_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -i FNAME, --inp FNAME\n");
fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
fprintf(stderr, " -o FNAME, --out FNAME\n");
fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}

examples/common.h

@@ -11,23 +11,30 @@
#define COMMON_SAMPLE_RATE 16000
//
// CLI argument parsing
// GPT CLI argument parsing
//
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
int32_t n_batch = 8; // batch size for prompt processing
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t repeat_last_n = 64;
float repeat_penalty = 1.00f;
int32_t n_batch = 8; // batch size for prompt processing
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt = "";
std::string token_test = "";
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt;
bool interactive = false;
int32_t interactive_port = -1;
int32_t n_gpu_layers = 0;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -53,11 +60,20 @@ struct gpt_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string & token);
};
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
std::string convert_to_utf8(const std::wstring & input);
std::wstring convert_to_wstring(const std::string & input);
void gpt_split_words(std::string str, std::vector<std::string>& words);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
@@ -70,6 +86,15 @@ std::map<std::string, int32_t> json_parse(const std::string & fname);
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
// test outputs of gpt_tokenize
//
// - compare with tokens generated by the huggingface tokenizer
// - test cases are chosen based on the model's main language (under 'prompt' directory)
// - if all sentences are tokenized identically, print 'All tests passed.'
// - otherwise, print sentence, huggingface tokens, ggml tokens
//
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test);
// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
@@ -89,6 +114,18 @@ gpt_vocab::id gpt_sample_top_k_top_p(
double temp,
std::mt19937 & rng);
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng);
//
// Audio utils
//
@@ -120,3 +157,20 @@ bool vad_simple(
// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1);
//
// SAM argument parsing
//
struct sam_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
std::string fname_inp = "img.jpg";
std::string fname_out = "img.out";
};
bool sam_params_parse(int argc, char ** argv, sam_params & params);
void sam_print_usage(int argc, char ** argv, const sam_params & params);

examples/grammar-parser.cpp Normal file

@@ -0,0 +1,423 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
#include <stdexcept>
#include <exception>
namespace grammar_parser {
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from whisper.cpp
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
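// lookup maps the first byte's high nibble to the UTF-8 sequence length:
// 0x0-0x7 (0xxxxxxx, ASCII) -> 1, 0xC-0xD (110xxxxx) -> 2, 0xE (1110xxxx) -> 3,
// 0xF (1111xxxx) -> 4; stray continuation bytes (0x8-0xB, 10xxxxxx) also map to 1.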
uint8_t first_byte = static_cast<uint8_t>(*src);
uint8_t highbits = first_byte >> 4;
int len = lookup[highbits];
uint8_t mask = (1 << (8 - len)) - 1;
uint32_t value = first_byte & mask;
const char * end = src + len; // may overrun!
const char * pos = src + 1;
for ( ; pos < end && *pos; pos++) {
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
}
return std::make_pair(value, pos);
}
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
return result.first->second;
}
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
return next_id;
}
void add_rule(
parse_state & state,
uint32_t rule_id,
const std::vector<whisper_grammar_element> & rule) {
if (state.rules.size() <= rule_id) {
state.rules.resize(rule_id + 1);
}
state.rules[rule_id] = rule;
}
bool is_word_char(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
const char * pos = src;
const char * end = src + size;
uint32_t value = 0;
for ( ; pos < end && *pos; pos++) {
value <<= 4;
char c = *pos;
if ('a' <= c && c <= 'f') {
value += c - 'a' + 10;
} else if ('A' <= c && c <= 'F') {
value += c - 'A' + 10;
} else if ('0' <= c && c <= '9') {
value += c - '0';
} else {
break;
}
}
if (pos != end) {
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
}
return std::make_pair(value, pos);
}
const char * parse_space(const char * src, bool newline_ok) {
const char * pos = src;
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
if (*pos == '#') {
while (*pos && *pos != '\r' && *pos != '\n') {
pos++;
}
} else {
pos++;
}
}
return pos;
}
const char * parse_name(const char * src) {
const char * pos = src;
while (is_word_char(*pos)) {
pos++;
}
if (pos == src) {
throw std::runtime_error(std::string("expecting name at ") + src);
}
return pos;
}
std::pair<uint32_t, const char *> parse_char(const char * src) {
if (*src == '\\') {
switch (src[1]) {
case 'x': return parse_hex(src + 2, 2);
case 'u': return parse_hex(src + 2, 4);
case 'U': return parse_hex(src + 2, 8);
case 't': return std::make_pair('\t', src + 2);
case 'r': return std::make_pair('\r', src + 2);
case 'n': return std::make_pair('\n', src + 2);
case '\\':
case '"':
case '[':
case ']':
return std::make_pair(src[1], src + 2);
default:
throw std::runtime_error(std::string("unknown escape at ") + src);
}
} else if (*src) {
return decode_utf8(src);
}
throw std::runtime_error("unexpected end of input");
}
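// For reference: in grammar text, \x41 decodes to 'A' (U+0041), \u00E9 to U+00E9,
// \t/\r/\n to the usual control characters, and \\, \", \[ and \] to themselves;
// any other character following a backslash is rejected as an unknown escape.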
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested);
const char * parse_sequence(
parse_state & state,
const char * src,
const std::string & rule_name,
std::vector<whisper_grammar_element> & out_elements,
bool is_nested) {
size_t last_sym_start = out_elements.size();
const char * pos = src;
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = out_elements.size();
while (*pos != '"') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = WHISPER_GRETYPE_CHAR_NOT;
}
last_sym_start = out_elements.size();
while (*pos != ']') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum whisper_gretype type = last_sym_start < out_elements.size()
? WHISPER_GRETYPE_CHAR_ALT
: start_type;
out_elements.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = out_elements.size();
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
last_sym_start = out_elements.size();
// output reference to synthesized rule
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
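// Example: given  x ::= "a"*  this synthesizes  x_N ::= "a" x_N |  (the second
// alternate is empty) and replaces "a" in the outer rule with a reference to x_N,
// where N is the freshly generated symbol id.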
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<whisper_grammar_element> sub_rule;
// add preceding symbol to generated rule
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
}
// mark start of alternate def
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
if (*pos == '+') {
// add preceding symbol as alternate only for '+' (otherwise empty)
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
}
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
pos = parse_space(pos + 1, is_nested);
} else {
break;
}
}
return pos;
}
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested) {
std::vector<whisper_grammar_element> rule;
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
while (*pos == '|') {
rule.push_back({WHISPER_GRETYPE_ALT, 0});
pos = parse_space(pos + 1, true);
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
}
rule.push_back({WHISPER_GRETYPE_END, 0});
add_rule(state, rule_id, rule);
return pos;
}
const char * parse_rule(parse_state & state, const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(state, src, name_len);
const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(state, pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
parse_state parse(const char * src) {
try {
parse_state state;
const char * pos = parse_space(src, true);
while (*pos) {
pos = parse_rule(state, pos);
}
return state;
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
return parse_state();
}
}
void print_grammar_char(FILE * file, uint32_t c) {
if (0x20 <= c && c <= 0x7f) {
fprintf(file, "%c", static_cast<char>(c));
} else {
// cop out of encoding UTF-8
fprintf(file, "<U+%04X>", c);
}
}
bool is_char_element(whisper_grammar_element elem) {
switch (elem.type) {
case WHISPER_GRETYPE_CHAR: return true;
case WHISPER_GRETYPE_CHAR_NOT: return true;
case WHISPER_GRETYPE_CHAR_ALT: return true;
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
default: return false;
}
}
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
for (auto elem : rule) {
switch (elem.type) {
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
}
switch (elem.type) {
case WHISPER_GRETYPE_END:
case WHISPER_GRETYPE_ALT:
case WHISPER_GRETYPE_RULE_REF:
fprintf(file, "(%u) ", elem.value);
break;
case WHISPER_GRETYPE_CHAR:
case WHISPER_GRETYPE_CHAR_NOT:
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
case WHISPER_GRETYPE_CHAR_ALT:
fprintf(file, "(\"");
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
break;
}
}
fprintf(file, "\n");
}
void print_rule(
FILE * file,
uint32_t rule_id,
const std::vector<whisper_grammar_element> & rule,
const std::map<uint32_t, std::string> & symbol_id_names) {
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
throw std::runtime_error(
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
}
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
whisper_grammar_element elem = rule[i];
switch (elem.type) {
case WHISPER_GRETYPE_END:
throw std::runtime_error(
"unexpected end of rule: " + std::to_string(rule_id) + "," +
std::to_string(i));
case WHISPER_GRETYPE_ALT:
fprintf(file, "| ");
break;
case WHISPER_GRETYPE_RULE_REF:
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
break;
case WHISPER_GRETYPE_CHAR:
fprintf(file, "[");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_NOT:
fprintf(file, "[^");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
fprintf(file, "-");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_ALT:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
print_grammar_char(file, elem.value);
break;
}
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
case WHISPER_GRETYPE_CHAR_ALT:
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
break;
default:
fprintf(file, "] ");
}
}
}
fprintf(file, "\n");
}
void print_grammar(FILE * file, const parse_state & state) {
try {
std::map<uint32_t, std::string> symbol_id_names;
for (auto kv : state.symbol_ids) {
symbol_id_names[kv.second] = kv.first;
}
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
// fprintf(file, "%zu: ", i);
// print_rule_binary(file, state.rules[i]);
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
// fprintf(file, "\n");
}
} catch (const std::exception & err) {
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
}
}
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
std::vector<const whisper_grammar_element *> ret;
for (const auto & rule : rules) {
ret.push_back(rule.data());
}
return ret;
}
}

examples/grammar-parser.h Normal file

@@ -0,0 +1,29 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by whisper.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root ::= expr
// expr ::= term ([-+*/] term)*
// term ::= num | "(" space expr ")" space
// num ::= [0-9]+ space
// space ::= [ \t\n]*
#pragma once
#include "whisper.h"
#include <vector>
#include <map>
#include <cstdint>
#include <string>
namespace grammar_parser {
struct parse_state {
std::map<std::string, uint32_t> symbol_ids;
std::vector<std::vector<whisper_grammar_element>> rules;
std::vector<const whisper_grammar_element *> c_rules() const;
};
parse_state parse(const char * src);
void print_grammar(FILE * file, const parse_state & state);
}
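// Illustrative usage (a sketch, not part of the upstream header):
//
//   grammar_parser::parse_state state = grammar_parser::parse("root ::= [0-9]+");
//   grammar_parser::print_grammar(stderr, state); // human-readable dump of the rules
//   std::vector<const whisper_grammar_element *> rules = state.c_rules(); // for the C API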

examples/lsp/CMakeLists.txt Normal file

@@ -0,0 +1,9 @@
if (WHISPER_SDL2)
# stream
set(TARGET lsp)
add_executable(${TARGET} lsp.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
endif ()

examples/lsp/README.md Normal file

@@ -0,0 +1,104 @@
# Language Server
This example consists of a simple language server that exposes both unguided
and guided (command) transcription by sending JSON messages over stdin/stdout,
as well as a rather robust vim plugin that makes use of the language server.
## Vim plugin quick start
Compile the language server with
```bash
make lsp
```
Install the plugin itself by copying or symlinking whisper.vim into ~/.vim/autoload/
In your vimrc, set the path of your whisper.cpp directory and optionally add some keybinds.
```vim
let g:whisper_dir = "~/whisper.cpp"
" Start listening for commands when Ctrl - g is pressed in normal mode
nnoremap <C-G> :call whisper#requestCommands()<CR>
" Start unguided transcription when Ctrl - g is pressed in insert mode
inoremap <C-G> <Cmd>call whisper#doTranscription()<CR>
```
## Vim plugin usage
The vim plugin was designed to closely follow the mnemonics of vim
`s:spoken_dict` is used to translate keys to their spoken form.
Keys corresponding to a string use that spoken value normally and when a motion is expected, but use the key itself when a character is expected.
Keys corresponding to a dict, like `i`, can have manual definitions given for each possible commandset.
0 is normal (insert), 1 is motion (inside), 2 is its usage as a single key ([till] i), and 3 is its usage in an area selection (s -> [around] sentence)
Some punctuation items, like `-`, are explicitly given pronunciations to prevent them from being picked up as punctuation instead of an actual command word.
Not all commands tokenize to a single token, and this can interfere with interpretation. "yank", for example, takes multiple tokens, so detection is more accurate when only the first "ya" is used. While it could be changed to something that is a single token ("copy"), value was placed on maintaining vim mnemonics.
Commands that would normally move the editor into insert mode (insert, append, open, change) will begin unguided transcription.
Unguided transcription will end when a speech segment ends in "exit".
The presence of punctuation can be designated by whether or not you add a pause between the previous speech segment and "exit".
Exiting only occurs if exit is the last word, so "Take the first exit on your right" would not cause transcription to end.
After a command is evaluated, the plugin will continue listening for the next command.
While in command mode, "Exit" will end listening.
A best effort approach is taken to keep track of audio that is recorded while a previous chunk is still processing and immediately interpret it afterwards, but the current voice detection still needs a fairly sizable gap to determine when a command has been spoken.
Log information is sent to a special `whisper_log` buffer and can be accessed with
```vim
:e whisper_log
```
## Vim plugin configuration
`g:whisper_dir`
A full path to the whisper.cpp repo. It can be expanded in the definition like so:
```vim
let g:whisper_dir = expand("~/whisper.cpp/")
```
(The WHISPER_CPP_HOME environment variable is also checked for users of the existing whisper.nvim script)
`g:whisper_lsp_path`
Can be used to manually set the path to the language server.
If not defined, it will be inferred from the above whisper_dir
`g:whisper_model_path`
A full path to the model to load. If not defined, it will default to ggml-base.en.bin
`g:whisper_user_commands`
A dictionary of spoken commands that correspond to either strings or funcrefs.
This can be used to create connections with other user plugins, for example
```vim
let g:whisper_user_commands = {"gen": "llama#doLlamaGen"}
```
will trigger the llama.cpp plugin to begin generation when "gen" is spoken
## Language server methods
`registerCommandset`
`params` is a list of strings that should be checked for with this commandset. The server prepends a space to these strings before tokenizing.
Responds with
`result.index` an integer index for the commandset registered, which should be included when initiating a guided transcription to select this commandset.
Will return an error if any of the commands in the commandset have duplicate tokenizations
`guided`
`params.commandset_index` An index returned by a corresponding commandset registration. If not set, the most recently registered commandset is used.
`params.timestamp` A positive unsigned integer designating the point in time from which audio processing should begin. If left blank, processing starts from the moment the message is received. This should be left blank unless you have a timestamp from a previous response.
Responds with
`result.command_index` The numerical index (starting from 0) of the detected command in the selected commandset
`result.command_text` A string containing the command as provided in the commandset
`result.timestamp` A positive unsigned integer designating the point in time at which audio stopped being processed. Pass this timestamp back in a subsequent message to mask the latency of transcription.
`unguided`
`params.no_context` Sets the corresponding whisper `no_context` param. Defaults to true. Might provide more accurate results for consecutive unguided transcriptions if those after the first are set to false.
`params.prompt` If provided, sets the initial prompt used during transcription.
`params.timestamp` A positive unsigned integer designating the point in time from which audio processing should begin. If left blank, processing starts from the moment the message is received. This should be left blank unless you have a timestamp from a previous response.
Responds with
`result.transcription` A string containing the transcribed text. N.B. This will almost always start with a space due to how text is tokenized.
`result.timestamp` A positive unsigned integer designating the point in time at which audio stopped being processed. Pass this timestamp back in a subsequent message to mask the latency of transcription.
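
As a concrete sketch (the field values here are illustrative, not real output): the messages follow JSON-RPC 2.0 and are framed with a `Content-Length` header, so a guided request and a matching response could look like
```json
{"jsonrpc": "2.0", "id": 1, "method": "guided", "params": {"commandset_index": 0}}

{"jsonrpc": "2.0", "id": 1, "result": {"command_index": 2, "command_text": "yank", "timestamp": 1693000000000}}
```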

examples/lsp/json.hpp Normal file

File diff suppressed because it is too large

examples/lsp/lsp.cpp Normal file

@@ -0,0 +1,458 @@
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "json.hpp"
#include <iostream>
#include <cassert>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <deque>
#include <set>
using json = nlohmann::json;
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 8000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
};
struct command {
std::vector<whisper_token> tokens;
std::string plaintext;
};
struct commandset {
std::vector<struct command> commands;
std::vector<whisper_token> prompt_tokens;
// TODO: Store longest command?
// Multi-token commands should have probabilities of subsequent logits
// given that the prior logit is correct.
// In this case, all commands must be iterated.
// This however, is likely highly involved as different tokens
// almost certainly have different spoken lengths
// It would also have performance implications equivalent to a beam search
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, "\n");
}
uint64_t wait_for_vad(audio_async & audio, json jparams, const whisper_params & params, uint64_t maxlength_ms, std::vector<float> & pcmf32) {
using namespace std::chrono;
uint64_t time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
uint64_t start_time = time_now;
if (jparams.contains("timestamp")) {
start_time = jparams.at("timestamp");
}
if(time_now - start_time < 500) {
//wait for a backlog of audio
std::this_thread::sleep_for(milliseconds(500 - (time_now - start_time)));
time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
} else if (time_now - start_time > 1000) {
audio.get(time_now-start_time, pcmf32);
size_t max_offset = pcmf32.size() - WHISPER_SAMPLE_RATE;
for(size_t offset=0;offset < max_offset;offset+=WHISPER_SAMPLE_RATE/10) {
std::vector<float> audio_chunk(&pcmf32[offset], &pcmf32[offset+WHISPER_SAMPLE_RATE]);
if(::vad_simple(audio_chunk, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
pcmf32.resize(offset+WHISPER_SAMPLE_RATE);
if (offset*1000/WHISPER_SAMPLE_RATE+1000 > maxlength_ms) {
//remove samples from the beginning
pcmf32.erase(pcmf32.begin(),pcmf32.end()-(maxlength_ms*WHISPER_SAMPLE_RATE/1000));
fprintf(stderr, "Shortened samples");
}
return start_time + offset*1000/WHISPER_SAMPLE_RATE+1000;
}
}
}
size_t window_duration = std::max((uint64_t)1000, time_now-start_time);
audio.get(window_duration, pcmf32);
while (!::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
std::this_thread::sleep_for(milliseconds(100));
time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
window_duration = std::max((uint64_t)1000,time_now-start_time);
audio.get(window_duration, pcmf32);
}
if (time_now - start_time > maxlength_ms) {
audio.get(maxlength_ms, pcmf32);
} else {
audio.get(time_now - start_time, pcmf32);
}
return time_now;
}
json unguided_transcription(struct whisper_context * ctx, audio_async &audio, json jparams, const whisper_params &params) {
std::vector<whisper_token> prompt_tokens;
std::vector<float> pcmf32;
uint64_t unprocessed_audio_timestamp = wait_for_vad(audio, jparams, params, 10000U, pcmf32);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
if (jparams.contains("prompt")) {
// unlikely to see much use. Under normal circumstances, no_context would be set to false
std::string prompt = jparams.at("prompt");
prompt_tokens.resize(1024);
int n = whisper_tokenize(ctx, prompt.c_str(), prompt_tokens.data(), 1024);
prompt_tokens.resize(n);
wparams.prompt_tokens = prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.size();
}
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = false;
wparams.translate = params.translate;
wparams.no_context = jparams.value("no_context", true);
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
throw json{
{"code", -32803},
{"message", "ERROR: whisper_full() failed"}
};
}
std::string result = whisper_full_get_segment_text(ctx,0);
return json {
{"transcription", result},
{"timestamp", unprocessed_audio_timestamp}
};
}
// command-list mode
// guide the transcription to match the most likely command from a provided list
json guided_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params, json jparams, std::vector<struct commandset> commandset_list) {
struct commandset cs = commandset_list[jparams.value("commandset_index", commandset_list.size()-1)];
std::vector<float> pcmf32;
uint64_t unprocessed_audio_timestamp = wait_for_vad(audio, jparams, params, 2000U, pcmf32);
fprintf(stderr, "%s: Speech detected! Processing ...\n", __func__);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = false;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// TODO: Do some time testing. Does an overly long prompt slow down processing?
// Set up command sets/precompute prompts
wparams.prompt_tokens = cs.prompt_tokens.data();
wparams.prompt_n_tokens = cs.prompt_tokens.size();
// TODO: properly expose as option
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
throw json{
{"code", -32803},
{"message", "ERROR: whisper_full() failed"}//TODO: format string (sprintf?)
};
}
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
// compute probs from logits via softmax
{
float max = -1e9;
for (int i = 0; i < (int) probs.size(); ++i) {
max = std::max(max, logits[i]);
}
float sum = 0.0f;
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] = expf(logits[i] - max);
sum += probs[i];
}
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
}
}
std::vector<std::pair<float, int>> probs_id;
// In my testing, the most verbose token is always the desired.
// TODO: Trim commandset struct once efficacy has been verified
for (int i = 0; i < (int) cs.commands.size(); ++i) {
probs_id.emplace_back(probs[cs.commands[i].tokens[0]], i);
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
int id = probs_id[0].second;
return json{
{"command_index", id},
{"command_text", cs.commands[id].plaintext},
{"timestamp", unprocessed_audio_timestamp},
};
}
}
json register_commandset(struct whisper_context * ctx, json jparams, std::vector<struct commandset> &commandset_list) {
// TODO: check for token collision
struct commandset cs;
std::string k_prompt = " select one from the available words: ";
std::set<whisper_token> token_set;
whisper_token tokens[32];
for (std::string s : jparams) {
std::vector<whisper_token> token_vec;
// The existing command implementation uses a nested for loop to tokenize single characters
// I fail to see the purpose of this when ' a' has a wholly different pronunciation than the start of ' apple'
const int n = whisper_tokenize(ctx, (" " + s).c_str(), tokens, 32);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, s.c_str());
return 3;
}
token_vec.push_back(tokens[0]);
if (!token_set.insert(tokens[0]).second) {
fprintf(stderr, "%s: warning: %s is a duplicate of an existing token\n", __func__, s.c_str());
throw json{
{"code",-31000},
{"message", "Duplicate token in token set: " + s}
};
}
if (n > 1) {// empty string if n=0? Should never occur
fprintf(stderr, "%s: error: command is more than a single token: %s\n", __func__, s.c_str());
}
struct command command = {token_vec, s};
cs.commands.push_back(command);
k_prompt += s + ", "; // comma-separated; the trailing ", " is trimmed below
}
k_prompt = k_prompt.substr(0,k_prompt.length()-2) + ". Selected word:";
cs.prompt_tokens.resize(1024);
int n = whisper_tokenize(ctx, k_prompt.c_str(), cs.prompt_tokens.data(), 1024);
cs.prompt_tokens.resize(n);
// prepare response
int index = commandset_list.size();
commandset_list.push_back(cs);
return json{{"index",index}};
}
json seek(struct whisper_context * ctx, audio_async &audio, json params) {
// whisper_state has the pertinent offsets, but there also seem to be a large
// number of scratch buffers that would prevent rewinding context in a manner similar to llama
// I'll give this another pass once everything else is implemented,
// but for now, it's unsupported
throw json{
{"code", -32601},
{"message", "Seeking is not yet supported."}
};
}
json parse_job(const json &body, struct whisper_context * ctx, audio_async &audio, const whisper_params &params, std::vector<struct commandset> &commandset_list) {
// See: https://www.jsonrpc.org/specification
json id = body.at("id");
try {
std::string version = body.at("jsonrpc");
if (version != "2.0") {
// unsupported version
throw json{
{"code", -3260},
{"message", "invalid jsonrpc version"}
};
}
std::string method = body.at("method");
json jparams = json{{"dummy", "dummy"}};
if (body.contains("params"))
jparams = body.at("params");
json res;
// TODO: be consistent about argument order
fprintf(stderr, "Dispatching a job\n");
if (method == "unguided") { res = unguided_transcription(ctx, audio, jparams, params); }
else if (method == "guided") { res = guided_transcription(ctx, audio, params, jparams, commandset_list); }
else if (method == "seek") { res = seek(ctx, audio, jparams); }
else if (method == "registerCommandset") { res = register_commandset(ctx, jparams, commandset_list); }
else if (method == "echo") { res = jparams; }
return json{
{"jsonrpc", "2.0"},
{"result", res},
{"id", id}
};
} catch(json ex) {
return json {
{"jsonrpc", "2.0"},
{"error", ex},
{"id", id}
};
}
}
void process_loop(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
std::deque<json> jobqueue;
std::vector<struct commandset> commandset_list;
while (true) {
// For eventual cancellation support, shouldn't block if job exists
if (std::cin.rdbuf()->in_avail() > 22 || jobqueue.size() == 0) {
int content_length;
if (scanf("Content-Length: %d", &content_length) != 1) {
fprintf(stderr, "Could not read input: %d", std::cin.peek());
return;
}
// scanf leaves the new lines intact
std::cin.ignore(2);
if (std::cin.peek() != 13) {
// Content-Type. jsonrpc necessitates utf8.
std::cin.ignore(200,10);
}
std::cin.ignore(2);
// A message is being sent and blocking is acceptable
std::string content(content_length,'\0');
std::cin.read(&content[0], content_length);
json job = json::parse(content);
// TODO: Some messages(cancellation) should skip queue here
if (job.is_array()) {
// response must also be batched. Will implement later
// for (subjob : job.begin())
// TODO: At the very least respond with an unsupported error.
} else {
jobqueue.push_back(job);
}
}
assert(jobqueue.size() > 0);
json job = jobqueue.front();
json resp = parse_job(job, ctx, audio, params, commandset_list);
if (resp != "unfinished") {
jobqueue.pop_front();
// send response
std::string data = resp.dump(-1, ' ', false, json::error_handler_t::replace);
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", data.length()+1, data.c_str());
std::cout.flush();
}
}
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
// TODO: Investigate why this is required. An extra second of startup latency is not great
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
// TODO: consider some sort of indicator to designate loading has finished?
// Potentially better for the client to just start with a non-blocking message (register commands)
process_loop(ctx, audio, params);
audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

examples/lsp/whisper.vim Normal file

@@ -0,0 +1,362 @@
if !exists("g:whisper_dir")
let g:whisper_dir = expand($WHISPER_CPP_HOME)
if g:whisper_dir == ""
echoerr "Please provide a path to the whisper.cpp repo in either the $WHISPER_CPP_HOME environment variable, or g:whisper_dir"
endif
endif
if !exists("g:whisper_lsp_path")
let g:whisper_lsp_path = g:whisper_dir .. "lsp"
if !filereadable(g:whisper_lsp_path)
echoerr "Was not able to locate a lsp executable at: " .. g:whisper_lsp_path
throw "Executable not found"
endif
endif
if !exists("g:whisper_model_path")
" TODO: allow custom paths relative to the repo dir
let g:whisper_model_path = g:whisper_dir .. "models/ggml-base.en.bin"
if !filereadable(g:whisper_model_path)
echoerr "Could not find model at: " .. g:whisper_model_path
throw "Model not found"
endif
endif
let s:output_buffer = bufnr("whisper_log", v:true)
call setbufvar(s:output_buffer,"&buftype","nofile")
let s:lsp_command = [g:whisper_lsp_path,"-m",g:whisper_model_path]
" For faster execution. TODO: server load multiple models/run multiple servers?
" let s:lsp_command = [g:whisper_lsp_path, "-m", g:whisper_dir .. "models/ggml-tiny.en.bin", "-ac", "128"]
" requestCommands([params_dict])
func whisper#requestCommands(...)
let l:req = {"method": "guided", "params": {"commandset_index": 0}}
if a:0 > 0
call extend(l:req.params, a:1)
endif
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:commandCallback", [l:req.params, 0])})
endfunction
" doTranscription([params_dict])
func whisper#doTranscription(...)
let l:req = {"method": "unguided", "params": {}}
if a:0 > 0
call extend(l:req.params, a:1)
endif
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:transcriptionCallback", [function("s:insertText"),function("s:endTranscription")])})
endfunction
" For testing
func whisper#uppertest(cha)
echo tr(a:cha, s:c_lowerkeys, s:c_upperkeys)
endfunction
" (upper, exit, count, motion, command, insert/append, save run) "base"
" (upper, exit, count, motion, command, inside/around) "motion/visual"
" (upper, exit, count, motion, line, inside/around) "command already entered"
" (upper, exit, key, ) "from/till"
" upper and lower keys is used to translate between cases with tr
" Must be sunchronized
let s:c_lowerkeys = "1234567890-=qwertyuiop[]\\asdfghjkl;'zxcvbnm,./\""
let s:c_upperkeys = "!@#$%^&*()_+QWERTYUIOP{}|ASDFGHJKL:\"ZXCVBNM<>?'"
let s:c_count = split("1234567890\"",'\zs')
let s:c_command = split("ryuogpdxcv.iam", '\zs')
let s:c_motion = split("wetf'hjklnb$^)",'\zs')
" object words: Word, Sentence, Paragraph, [, (, <, Tag, {. ", '
let s:c_area = split("wsp])>t}\"'",'\zs')
"Special commands.
let s:c_special_always = ["exit", "upper"]
let s:c_special_normal = ["save", "run", "space"]
" If not in dict, key is spoken word,
" If key resolves to string, value is used for normal/motion, but key for chars
" If key resolves to dict, {0: "normal",1: "motion",2:"single char",3: "area"}
" Missing entries fall back as follows {0: "required", 1: 0, 2: "key", 3: 0}
let s:spoken_dict = {"w": "word", "e": "end", "r": "replace", "t": {0: "till", 3: "tag"}, "y": "yank", "u": "undo", "i": {0: "insert", 1: "inside"}, "o": "open", "p": {0: "paste", 3: "paragraph"}, "a": {0: "append", 1: "around"}, "s": {0: "substitute", 3: "sentence"}, "d": "delete", "f": "from", "g": "go", "h": "left", "j": "down", "k": "up", "l": "right", "c": "change", "v": "visual", "b": "back", "n": "next", "m": "mark", ".": {0: "repeat", 2: "period"}, "]": {0: "bracket", 2: "bracket"}, "'": {0: "jump", 2: "apostrophe", 3: "apostrophe"}, '"': {0: 'register', 2: "quotation", 3: "quotation"}, "-": {0: "minus", 2: "minus"}, "$": {0: "dollar", 2: "dollar"}, "^": {0: "carrot", 2: "carrot"}, ")": {0: "sentence", 2: "parenthesis", 3: "parenthesis"}, "}": {0: "paragraph", 2: "brace", 3: "brace"}, ">": {0: "indent", 2: "angle", 3: "angle"}}
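" e.g. 'i' is spoken as "insert" when a command is expected but as "inside" when a
" motion is expected; 't' is "till" normally and "tag" in an area selection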
" Give this another pass. This seems overly hacky even if it's functional
let s:sub_tran_msg = ""
func s:subTranProg(msg)
if s:sub_tran_msg != ""
let s:sub_tran_msg = s:sub_tran_msg .. a:msg
if mode() !=? 'v'
exe "normal" "u" .. s:sub_tran_msg
endif
else
if s:command_backlog == ""
" this should not occur
call s:logCallback(0, "Warning: Encountered sub transcription without prior command")
let s:command_backlog = "a"
endif
if a:msg[0] == ' '
let s:sub_tran_msg = s:command_backlog .. a:msg[1:-1]
else
let s:sub_tran_msg = s:command_backlog .. a:msg
endif
if mode() !=? 'v'
exe "normal" s:sub_tran_msg
endif
endif
call appendbufline(s:output_buffer, "$", s:sub_tran_msg .. ":" .. string(a:msg ))
endfunction
func s:subTranFinish(params, timestamp)
let s:repeat_command = s:sub_tran_msg
" Visual selection is lot if used with streaming, so streaming of partial
" transcriptions is disabled in visual mode
if mode() ==? 'v'
exe "normal" s:sub_tran_msg
endif
let s:sub_tran_msg = ""
let s:command_backlog = ""
exe "normal a\<C-G>u"
let l:params = a:params
let l:params.timestamp = a:timestamp
if exists("l:params.commandset_index")
unlet l:params.commandset_index
endif
call whisper#requestCommands(a:params)
endfunction
func s:logCallback(channel, msg)
call appendbufline(s:output_buffer,"$",a:msg)
endfunction
func s:transcriptionCallback(progressCallback, finishedCallback, channel, msg)
let l:tr = a:msg.result.transcription
let l:ex_ind = match(tolower(l:tr),"exit", len(l:tr)-6)
" The worst case I've observed so far is " Exit.", which is 6 characters
if l:ex_ind != -1
call a:progressCallback(strpart(l:tr,0,l:ex_ind-1))
call a:finishedCallback(a:msg.result.timestamp)
else
call a:progressCallback(l:tr)
let req = {"method": "unguided", "params": {"timestamp": a:msg.result.timestamp, "no_context": v:true}}
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [a:progressCallback, a:finishedCallback])})
endif
endfunc
func s:insertText(msg)
exe "normal a" .. a:msg
endfunction
func s:endTranscription(timestamp)
call appendbufline(s:output_buffer, "$", "Ending unguided transcription")
endfunction
" If a command does not include a whole actionable step, attempting to execute
" it discards the remainder of things. There is likely a simpler solution,
" but it can be made functional now by storing a backbuffer until actionable
let s:command_backlog = ""
let s:repeat_command = ""
let s:preceeding_upper = v:false
func s:commandCallback(params, commandset_index, channel, msg)
let l:command_index = a:msg.result.command_index
let l:do_execute = v:false
let l:next_mode = a:commandset_index
let l:command = s:commandset_list[a:commandset_index][l:command_index]
call s:logCallback(0, string(a:msg) .. " " .. a:commandset_index .. " " .. l:command)
if l:command_index == 0
"exit
"if s:command_backlog == ""
call s:logCallback(0,"Stopping command mode")
echo "No longer listening"
let s:command_backlog = ""
return
"else
" Legacy code to clear an existing buffer with exit.
" Was found to be rarely desired and is better introduced as a
" standalone command (clear?)
" call s:logCallback(0,"Clearing command_backlog" .. s:command_backlog)
" let s:command_backlog = ""
" let s:preceeding_upper = v:false
" endif
elseif l:command_index == 1
" upper
let s:preceeding_upper = !s:preceeding_upper
elseif l:command == "save"
" save and run can only happen in commandset 0,
exe "w"
elseif l:command == "run"
exe "make run"
elseif l:command == "space"
exe "normal i \<ESC>l"
elseif has_key(s:c_user, l:command)
let Userfunc = s:c_user[l:command]
if type(Userfunc) == v:t_string
let Userfunc = function(Userfunc)
endif
call Userfunc()
else
if s:preceeding_upper
" Upper should keep commandset
let s:preceeding_upper = v:false
let l:visual_command = tr(l:command, s:c_lowerkeys, s:c_upperkeys)
else
let l:visual_command = l:command
endif
echo s:command_backlog .. " - " .. l:visual_command
let s:command_backlog = s:command_backlog .. l:visual_command
if a:commandset_index == 2 || a:commandset_index == 3
" single key, either completes motion, replace, or register
" Should move to execute unless part of a register
" Change will be caught at execute
if s:command_backlog[-2:-2] !=# '"'
call s:logCallback(0,"not register")
let l:do_execute = v:true
end
let l:next_mode = 0
" commandset index only matters for a/i
elseif (l:command == "a" || l:command == "i") && a:commandset_index == 1
" inside/around. Is commandset 3
let l:next_mode = 3
elseif l:command ==# '"'
let l:next_mode = 2
elseif index(s:c_count, l:command) != -1
let l:next_mode = a:commandset_index
elseif index(s:c_motion, l:command) != -1
if l:command == 't' || l:command == 'f' || l:command == "'"
" prompt single key
let l:next_mode = 2
else
let l:do_execute = v:true
let l:next_mode = 0
endif
elseif index(s:c_command, l:command) != -1
if index(["y","g","d","c"], s:command_backlog[-1:-1]) != -1 && s:command_backlog[-1:-1] != s:command_backlog[-2:-2] && mode() !=? 'v'
" need motion or repeated command
" Potential for bad state here if disparaging command keys are
" entered (i.e. yd), but vim can handle checks for this at exe
" And checking for cases like y123d would complicate things
let l:next_mode = 1
elseif index(["i","a","c", "o", "s"], l:command) != -1 || s:command_backlog[-1:-1] ==# 'R'
"'Insert' mode, do general transcription
let l:req = {"method": "unguided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.no_context = v:true
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [function("s:subTranProg"), function("s:subTranFinish", [a:params])])})
return
elseif l:command == 'r' || l:command == 'm'
let l:next_mode = 2
elseif l:command == '.'
let l:next_mode = 0
let l:do_execute = v:true
let s:command_backlog = s:command_backlog[0:-2] .. s:repeat_command
else
if l:command ==? 'v'
let l:next_mode = 1
else
let l:next_mode = 0
endif
let l:do_execute = v:true
endif
else
throw "Invalid command state: " .. l:command .. " " .. a:commandset_index .. " " .. s:command_backlog
endif
endif
if l:do_execute
if mode() ==?'v' && l:next_mode == 0
let l:next_mode = 1
elseif match(s:command_backlog, 'c') != -1
let l:req = {"method": "unguided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.no_context = v:true
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [function("s:subTranProg"), function("s:subTranFinish", [a:params])])})
return
endif
exe "normal" s:command_backlog
if index(s:c_motion + ["u"],l:command) == -1
exe "normal a\<C-G>u"
let s:repeat_command = s:command_backlog
call s:logCallback(0, s:command_backlog)
endif
let s:command_backlog = ""
endif
let l:req = {"method": "guided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.commandset_index = l:next_mode
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:commandCallback",[a:params, l:next_mode])})
endfunction
func s:loadedCallback(channel, msg)
echo "Loading complete"
call s:logCallback(a:channel, a:msg)
endfunction
func s:registerCommandset(commandlist, is_final)
let req = {"method": "registerCommandset"}
let req.params = a:commandlist
call s:logCallback(0, join(a:commandlist))
call add(g:whisper_commandlist_spoken, a:commandlist)
if a:is_final
let resp = ch_sendexpr(g:lsp_job, req, {"callback": "s:loadedCallback"})
else
let resp = ch_sendexpr(g:lsp_job, req, {"callback": "s:logCallback"})
endif
endfunction
func s:registerAllCommands()
let l:normal = s:c_special_always + s:c_special_normal + s:c_count + s:c_command + s:c_motion + keys(s:c_user)
let l:visual = s:c_special_always + s:c_count + s:c_command + s:c_motion
" Currently the same as visual.
" let l:post_command = s:c_special_always + s:c_count + s:c_command + s:c_motion
let l:single_key = s:c_special_always + split(s:c_lowerkeys, '\zs')
let l:area = s:c_special_always + s:c_area
" Used only for compatibility with the testing script
let g:whisper_commandlist_spoken = []
let s:commandset_list = [l:normal, l:visual, l:single_key, l:area]
call s:registerCommandset(s:commandsetToSpoken(l:normal, 0), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:visual, 1), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:single_key, 2), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:area, 3), v:true)
endfunction
func s:commandsetToSpoken(commandset, spoken_index)
let l:spoken_list = []
for l:command in a:commandset
if has_key(s:spoken_dict, l:command)
let l:spoken_value = s:spoken_dict[l:command]
if type(l:spoken_value) == v:t_dict
if has_key(l:spoken_value, a:spoken_index)
let l:spoken_value = l:spoken_value[a:spoken_index]
else
if a:spoken_index == 2
let l:spoken_value = l:command
else
let l:spoken_value = l:spoken_value[0]
endif
endif
else
if a:spoken_index == 2
let l:spoken_value = l:command
endif
endif
else
let l:spoken_value = l:command
endif
call add(l:spoken_list, l:spoken_value)
endfor
return l:spoken_list
endfunction
" TODO: Check lifetime. If the script is resourced, is the existing
" s:lsp_job dropped and therefore killed?
" This seems to not be the case and I've had to deal with zombie processes
" that survive exiting vim, even though said behavior conflicts with my
" understanding of the provided documentation
let s:lsp_opts = {"in_mode": "lsp", "out_mode": "lsp", "err_mode": "nl", "err_io": "buffer", "err_buf": s:output_buffer}
if !exists("g:lsp_job")
if exists("g:whisper_user_commands")
let s:c_user = g:whisper_user_commands
else
let s:c_user = {}
endif
let g:lsp_job = job_start(s:lsp_command, s:lsp_opts)
if job_status(g:lsp_job) == "fail"
echoerr "Failed to start whisper job"
endif
call s:registerAllCommands()
endif
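" ---------------------------------------------------------------------------
" Minimal usage sketch. The paths below are placeholders; the two variables
" must be set before this script is sourced, and g:whisper_user_commands is
" an optional, hypothetical example of a user-defined command:
"   let g:whisper_lsp_path   = "/path/to/lsp"
"   let g:whisper_model_path = "/path/to/ggml-model.bin"
"   let g:whisper_user_commands = {"format": "MyFormatFunc"}
" Start guided voice command mode, or free-form dictation at the cursor:
"   :call whisper#requestCommands()
"   :call whisper#doTranscription()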

View File

@ -10,6 +10,10 @@
#include <vector>
#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
const std::vector<std::string> k_colors = {
@ -55,6 +59,7 @@ struct whisper_params {
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
@ -64,28 +69,36 @@ struct whisper_params {
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
bool speed_up = false;
bool translate = false;
bool detect_language= false;
bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
std::string language = "en";
std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string model = "models/ggml-base.en.bin";
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
@ -111,41 +124,45 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -175,9 +192,11 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
@ -191,12 +210,14 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, "\n");
}
@ -204,8 +225,50 @@ struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
int progress_prev;
};
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "0";
} else if (energy1 > 1.1*energy0) {
speaker = "1";
} else {
speaker = "?";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only) {
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
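// The heuristic above is a simple stereo energy comparison: a channel must
// carry at least 10% more summed energy over the segment to be labeled
// "0" or "1"; otherwise the speaker is reported as "?"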
void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
if (progress >= *progress_prev + progress_step) {
*progress_prev += progress_step;
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
}
}
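// (with the default progress_step of 5, the callback above reports roughly every 5%)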
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -235,28 +298,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
if (params.print_colors) {
@ -281,6 +323,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
@ -290,7 +338,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) {
bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -302,13 +350,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
fout << text << "\n";
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << speaker << text << "\n";
}
return true;
}
bool output_vtt(struct whisper_context * ctx, const char * fname) {
bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -324,15 +381,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << text << "\n\n";
fout << speaker << text << "\n\n";
}
return true;
}
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -346,10 +411,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << i + 1 + params.offset_n << "\n";
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << text << "\n\n";
fout << speaker << text << "\n\n";
}
return true;
@ -386,7 +457,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
return escaped;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -396,7 +467,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,text\n";
fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2)
{
fout << "speaker,";
}
fout << "text\n";
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
@ -404,13 +481,37 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
char * text_escaped = escape_double_quotes_and_backslashes(text);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
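//e.g. t0 = 250 (in units of 10 ms) -> 2500 ms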
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
fout << 10 * t0 << "," << 10 * t1 << ",";
if (params.diarize && pcmf32s.size() == 2)
{
fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
}
fout << "\"" << text_escaped << "\"\n";
}
return true;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
// fprintf(stderr,"segments: %d\n",n_segments);
for (int i = 0; i < n_segments; ++i) {
const int n_tokens = whisper_full_n_tokens(ctx, i);
// fprintf(stderr,"tokens: %d\n",n_tokens);
for (int j = 0; j < n_tokens; j++) {
auto token = whisper_full_get_token_text(ctx, i, j);
auto probability = whisper_full_get_token_p(ctx, i, j);
fout << token << '\t' << probability << std::endl;
// fprintf(stderr,"token: %s %f\n",token,probability);
}
}
return true;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
int indent = 0;
@ -424,13 +525,13 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_arr = [&](bool end = false) {
auto end_arr = [&](bool end) {
indent--;
doindent();
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name = nullptr) {
auto start_obj = [&](const char *name) {
doindent();
if (name) {
fout << "\"" << name << "\": {\n";
@ -440,7 +541,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_obj = [&](bool end = false) {
auto end_obj = [&](bool end) {
indent--;
doindent();
fout << (end ? "}\n" : "},\n");
@ -451,24 +552,24 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
fout << "\"" << name << "\": ";
};
auto value_s = [&](const char *name, const char *val, bool end = false) {
auto value_s = [&](const char *name, const char *val, bool end) {
start_value(name);
char * val_escaped = escape_double_quotes_and_backslashes(val);
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
free(val_escaped);
};
auto end_value = [&](bool end = false) {
auto end_value = [&](bool end) {
fout << (end ? "\n" : ",\n");
};
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
auto value_i = [&](const char *name, const int64_t val, bool end) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end = false) {
auto value_b = [&](const char *name, const bool val, bool end) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
@ -480,53 +581,62 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj();
value_s("systeminfo", whisper_print_system_info());
start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false);
start_obj("model");
value_s("type", whisper_model_type_readable(ctx));
value_b("multilingual", whisper_is_multilingual(ctx));
value_i("vocab", whisper_model_n_vocab(ctx));
value_s("type", whisper_model_type_readable(ctx), false);
value_b("multilingual", whisper_is_multilingual(ctx), false);
value_i("vocab", whisper_model_n_vocab(ctx), false);
start_obj("audio");
value_i("ctx", whisper_model_n_audio_ctx(ctx));
value_i("state", whisper_model_n_audio_state(ctx));
value_i("head", whisper_model_n_audio_head(ctx));
value_i("ctx", whisper_model_n_audio_ctx(ctx), false);
value_i("state", whisper_model_n_audio_state(ctx), false);
value_i("head", whisper_model_n_audio_head(ctx), false);
value_i("layer", whisper_model_n_audio_layer(ctx), true);
end_obj();
end_obj(false);
start_obj("text");
value_i("ctx", whisper_model_n_text_ctx(ctx));
value_i("state", whisper_model_n_text_state(ctx));
value_i("head", whisper_model_n_text_head(ctx));
value_i("ctx", whisper_model_n_text_ctx(ctx), false);
value_i("state", whisper_model_n_text_state(ctx), false);
value_i("head", whisper_model_n_text_head(ctx), false);
value_i("layer", whisper_model_n_text_layer(ctx), true);
end_obj();
value_i("mels", whisper_model_n_mels(ctx));
end_obj(false);
value_i("mels", whisper_model_n_mels(ctx), false);
value_i("ftype", whisper_model_ftype(ctx), true);
end_obj();
end_obj(false);
start_obj("params");
value_s("model", params.model.c_str());
value_s("language", params.language.c_str());
value_s("model", params.model.c_str(), false);
value_s("language", params.language.c_str(), false);
value_b("translate", params.translate, true);
end_obj();
end_obj(false);
start_obj("result");
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
end_obj();
end_obj(false);
start_arr("transcription");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj();
start_obj(nullptr);
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str());
value_s("from", to_timestamp(t0, true).c_str(), false);
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj();
end_obj(false);
start_obj("offsets");
value_i("from", t0 * 10);
value_i("from", t0 * 10, false);
value_i("to", t1 * 10, true);
end_obj();
value_s("text", text, true);
end_obj(false);
value_s("text", text, !params.diarize && !params.tinydiarize);
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
}
if (params.tinydiarize) {
value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
}
end_obj(i == (n_segments - 1));
}
@ -538,7 +648,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -575,6 +685,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
@ -583,13 +698,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
continue;
}
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
std::string txt_bg = "";
std::string txt_fg = ""; // highlight token
std::string txt_ul = ""; // underline
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
if (params.diarize && pcmf32s.size() == 2) {
txt_bg = speaker;
txt_fg = speaker;
txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
}
txt_bg.append("> ");
txt_fg.append("> ");
txt_ul.append("\\ \\ ");
{
for (int k = 0; k < n; ++k) {
@ -652,8 +773,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
return true;
}
bool output_lrc(struct whisper_context * ctx, const char * fname) {
bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -678,8 +798,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
char buf[16];
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
std::string timestamp_lrc = std::string(buf);
std::string speaker = "";
fout << '[' << timestamp_lrc << ']' << text << "\n";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
}
return true;
@ -689,6 +817,7 @@ int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
return 1;
}
@ -704,6 +833,12 @@ int main(int argc, char ** argv) {
exit(0);
}
if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@ -713,6 +848,9 @@ int main(int argc, char ** argv) {
return 3;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@ -745,11 +883,12 @@ int main(int argc, char ** argv) {
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
@ -779,6 +918,9 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.initial_prompt = params.prompt.c_str();
@ -789,7 +931,7 @@ int main(int argc, char ** argv) {
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s };
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
if (!wparams.print_realtime) {
@ -797,6 +939,11 @@ int main(int argc, char ** argv) {
wparams.new_segment_callback_user_data = &user_data;
}
if (wparams.print_progress) {
wparams.progress_callback = whisper_print_progress_callback;
wparams.progress_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
@ -823,43 +970,49 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str());
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str());
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params);
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str());
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}
// output to score file
if (params.log_score) {
const auto fname_score = fname_out + ".score.txt";
output_score(ctx, fname_score.c_str(), params, pcmf32s);
}
}
}

View File

@ -25,7 +25,7 @@ struct whisper_hparams {
int32_t n_text_head = 6;
int32_t n_text_layer = 4;
int32_t n_mels = 80;
int32_t f16 = 1;
int32_t ftype = 1;
};
struct whisper_filters {
@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
{
uint32_t magic;
finp.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
if (magic != GGML_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
return false;
}
@ -79,7 +79,10 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
finp.read((char *) &hparams.f16, sizeof(hparams.f16));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
@ -91,19 +94,22 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((char *) &ftype, sizeof(hparams.f16));
fout.write((const char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((const char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((const char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((const char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((const char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((const char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((const char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((const char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((const char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((const char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((const char *) &ftype_dst, sizeof(hparams.ftype));
}
// load mel filters
@ -132,15 +138,17 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
// return false;
//}
std::string word;
char word[129];
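// note: assumes vocab words are at most 128 bytes - word[len] below would overflow otherwise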
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
finp.read ((char *) &len, sizeof(len));
fout.write((char *) &len, sizeof(len));
word.resize(len);
finp.read ((char *) word.data(), len);
fout.write((char *) word.data(), len);
word[len] = '\0';
finp.read ((char *) word, len);
fout.write((char *) word, len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;

View File

@ -47,6 +47,7 @@ struct whisper_params {
bool print_special = false;
bool no_context = true;
bool no_timestamps = false;
bool tinydiarize = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
@ -80,6 +81,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -113,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, "\n");
}
@ -299,6 +303,8 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
// disable temperature fallback
//wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
@ -344,10 +350,19 @@ int main(int argc, char ** argv) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
std::string output = "[" + to_timestamp(t0) + " --> " + to_timestamp(t1) + "] " + text;
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
output += " [SPEAKER_TURN]";
}
output += "\n";
printf("%s", output.c_str());
fflush(stdout);
if (params.fname_out.length() > 0) {
fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "] " << text << std::endl;
fout << output;
}
}
}

View File

@ -42,8 +42,8 @@ Example usage:
## TTS
For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use MacOS's `say`, but you can use whatever you wish.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or Windows SpeechSynthesizer, but you can use whatever you wish.
## Discussion

View File

@ -1,23 +1,20 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk-llama/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
from elevenlabs import generate, play, save
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
# Save the TTS to a file
audio.save("audio")
save(audio, "audio.mp3")

View File

@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include <stdexcept>
#ifdef __has_include
#if __has_include(<unistd.h>)
@ -74,7 +75,7 @@ struct llama_file {
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw format("failed to open %s: %s", fname, std::strerror(errno));
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
@ -100,17 +101,17 @@ struct llama_file {
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
void read_raw(void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
std::size_t ret = std::fread(ptr, len, 1, fp);
if (ferror(fp)) {
throw format("read error: %s", strerror(errno));
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret != 1) {
throw std::string("unexpectedly reached end of file");
throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
}
@ -126,14 +127,14 @@ struct llama_file {
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t size) {
if (size == 0) {
void write_raw(const void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, size, 1, fp);
size_t ret = std::fwrite(ptr, len, 1, fp);
if (ret != 1) {
throw format("write error: %s", strerror(errno));
throw std::runtime_error(format("write error: %s", strerror(errno)));
}
}
@ -171,7 +172,7 @@ struct llama_mmap {
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, bool prefetch = true) {
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
@ -180,13 +181,13 @@ struct llama_mmap {
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw format("mmap failed: %s", strerror(errno));
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
}
if (prefetch) {
if (prefetch > 0) {
// Advise the kernel to preload the mapped memory
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
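// prefetch is now a byte count: (size_t) -1 requests prefetching the whole
// mapping and 0 disables it; posix_madvise is advisory only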
@ -207,7 +208,7 @@ struct llama_mmap {
DWORD error = GetLastError();
if (hMapping == NULL) {
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
@ -215,7 +216,7 @@ struct llama_mmap {
CloseHandle(hMapping);
if (addr == NULL) {
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
@ -243,8 +244,9 @@ struct llama_mmap {
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *) {
throw std::string("mmap not supported");
llama_mmap(struct llama_file *, bool prefetch = true) {
(void)prefetch;
throw std::runtime_error(std::string("mmap not supported"));
}
#endif
};
@ -265,9 +267,9 @@ struct llama_mlock {
}
}
void init(void * addr) {
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
this->addr = addr;
void init(void * ptr) {
LLAMA_ASSERT(addr == NULL && size == 0);
addr = ptr;
}
void grow_to(size_t target_size) {
@ -338,14 +340,14 @@ struct llama_mlock {
return (size_t) si.dwPageSize;
}
bool raw_lock(void * addr, size_t size) {
bool raw_lock(void * ptr, size_t len) {
for (int tries = 1; ; tries++) {
if (VirtualLock(addr, size)) {
if (VirtualLock(ptr, len)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
size, this->size, llama_format_win_err(GetLastError()).c_str());
len, size, llama_format_win_err(GetLastError()).c_str());
return false;
}
@ -361,7 +363,7 @@ struct llama_mlock {
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = size + 1048576;
size_t increment = len + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
@ -373,8 +375,8 @@ struct llama_mlock {
}
}
void raw_unlock(void * addr, size_t size) {
if (!VirtualUnlock(addr, size)) {
void raw_unlock(void * ptr, size_t len) {
if (!VirtualUnlock(ptr, len)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
@ -382,11 +384,16 @@ struct llama_mlock {
#else
static constexpr bool SUPPORTED = false;
void raw_lock(const void * addr, size_t size) {
fprintf(stderr, "warning: mlock not supported on this system\n");
size_t lock_granularity() {
return (size_t) 65536;
}
void raw_unlock(const void * addr, size_t size) {}
bool raw_lock(const void * addr, size_t len) {
fprintf(stderr, "warning: mlock not supported on this system\n");
return false;
}
void raw_unlock(const void * addr, size_t len) {}
#endif
};
@ -395,36 +402,70 @@ struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
llama_buffer() = default;
void resize(size_t len) {
delete[] addr;
addr = new uint8_t[size];
this->size = size;
addr = new uint8_t[len];
size = len;
}
~llama_buffer() {
delete[] addr;
}
// disable copy and move
llama_buffer(const llama_buffer&) = delete;
llama_buffer(llama_buffer&&) = delete;
llama_buffer& operator=(const llama_buffer&) = delete;
llama_buffer& operator=(llama_buffer&&) = delete;
};
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
struct llama_ctx_buffer {
uint8_t * addr = NULL;
bool is_cuda;
size_t size = 0;
llama_ctx_buffer() = default;
void resize(size_t size) {
if (addr) {
ggml_cuda_host_free(addr);
}
free();
addr = (uint8_t *) ggml_cuda_host_malloc(size);
if (addr) {
is_cuda = true;
}
else {
// fall back to pageable memory
addr = new uint8_t[size];
is_cuda = false;
}
this->size = size;
}
~llama_ctx_buffer() {
void free() {
if (addr) {
ggml_cuda_host_free(addr);
if (is_cuda) {
ggml_cuda_host_free(addr);
}
else {
delete[] addr;
}
}
addr = NULL;
}
~llama_ctx_buffer() {
free();
}
// disable copy and move
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
};
#else
typedef llama_buffer llama_ctx_buffer;

File diff suppressed because it is too large

View File

@ -19,11 +19,17 @@
# define LLAMA_API
#endif
#define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 'ggjt'
#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
#define LLAMA_SESSION_MAGIC 'ggsn'
#define LLAMA_SESSION_VERSION 0
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
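// each magic is just the four ASCII characters of its tag packed into a u32,
// e.g. 'g','g','j','t' = 0x67,0x67,0x6a,0x74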
#define LLAMA_FILE_VERSION 3
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 1
#ifdef __cplusplus
extern "C" {
@ -40,9 +46,9 @@ extern "C" {
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
} llama_token_data;
typedef struct llama_token_data_array {
@ -54,9 +60,9 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
int seed; // RNG seed, 0 for random
int n_ctx; // text context
int n_gpu_layers; // number of layers to store in VRAM
int seed; // RNG seed, -1 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
@ -73,16 +79,16 @@ extern "C" {
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
};
LLAMA_API struct llama_context_params llama_context_default_params();
@ -90,6 +96,13 @@ extern "C" {
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// TODO: not great API - very likely to change
// Initialize the llama + ggml backend
// Call once at the start of the program
LLAMA_API void llama_init_backend();
LLAMA_API int64_t llama_time_us();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
@ -122,26 +135,28 @@ extern "C" {
int n_threads);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
// Set the state reading from the specified address
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src);
// Save/load session file
LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
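
For illustration (not from the diff), the three state functions compose naturally: query the maximum size, copy into a caller-owned buffer, restore later. A minimal sketch with hypothetical helper names, assuming a valid `ctx`:

```cpp
#include "llama.h"
#include <cstdint>
#include <vector>

// Sketch: snapshot the full llama state (rng, logits, embedding, kv_cache)
// into a caller-owned buffer and restore it later.
std::vector<uint8_t> snapshot(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx)); // maximum size
    buf.resize(llama_copy_state_data(ctx, buf.data()));  // actual bytes written
    return buf;
}

void restore(llama_context * ctx, std::vector<uint8_t> & buf) {
    llama_set_state_data(ctx, buf.data());
}
```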
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
@ -165,9 +180,9 @@ extern "C" {
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
LLAMA_API int llama_n_embd (struct llama_context * ctx);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
@ -181,7 +196,7 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
@ -191,25 +206,25 @@ extern "C" {
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
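
Because a C header cannot carry default arguments, `min_keep` must now be passed explicitly at every call site. A minimal sketch of a typical temperature-sampling chain (hypothetical helper, mirroring the talk-llama changes further below):

```cpp
#include "llama.h"

// Sketch: top-k / top-p / temperature sampling with explicit min_keep.
// Assumes `candidates_p` was filled from llama_get_logits(ctx).
llama_token sample_next(llama_context * ctx, llama_token_data_array * candidates_p,
                        int top_k, float top_p, float temp) {
    llama_sample_top_k      (ctx, candidates_p, top_k, /*min_keep =*/ 1);
    llama_sample_top_p      (ctx, candidates_p, top_p, /*min_keep =*/ 1);
    llama_sample_temperature(ctx, candidates_p, temp);
    return llama_sample_token(ctx, candidates_p);
}
```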


@ -13,8 +13,11 @@
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk-llama/eleven-labs.py
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with a limited number of characters. To increase this limit, register at https://beta.elevenlabs.io to get an API key and paste it after 'ELEVEN_API_KEY='
# Keep the line commented to use the free version without an API key
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1


@ -0,0 +1 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2


@ -0,0 +1,12 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
# voice options are David or Zira
[Parameter(Mandatory=$true)][string]$voice,
[Parameter(Mandatory=$true)][string]$text
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$speak.SelectVoice("Microsoft $voice Desktop");
$speak.Rate="0";
$speak.Speak($text);


@ -33,8 +33,6 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_parts_llama = -1;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
@ -49,7 +47,7 @@ struct whisper_params {
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak.sh";
std::string speak = "./examples/talk-llama/speak";
std::string prompt = "";
std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
@ -72,7 +70,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
@ -123,7 +120,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
@ -239,13 +235,14 @@ int main(int argc, char ** argv) {
// llama init
llama_init_backend();
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 2048;
lparams.seed = 1;
lparams.f16_kv = true;
lparams.n_parts = params.n_parts_llama;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
@ -560,7 +557,7 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
@ -577,7 +574,7 @@ int main(int argc, char ** argv) {
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
}
llama_token id = 0;
@ -609,8 +606,8 @@ int main(int argc, char ** argv) {
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k);
llama_sample_top_p(ctx_llama, &candidates_p, top_p);
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temperature(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}


@ -191,9 +191,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// create the ggml context
{
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = NULL,
.no_alloc = false,
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
model.ctx = ggml_init(params);
@ -420,7 +420,6 @@ bool gpt2_eval(
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -442,7 +441,7 @@ bool gpt2_eval(
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL);
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
@ -589,7 +588,7 @@ bool gpt2_eval(
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
@ -644,7 +643,7 @@ bool gpt2_eval(
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL);
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
@ -664,8 +663,8 @@ bool gpt2_eval(
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
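
The dropped `gf.n_threads = n_threads;` line above reflects an API change from the ggml sync: the thread count no longer lives on the graph but is passed to the compute call. A before/after sketch (fragment; `ctx0`, `inpL`, and `n_threads` come from the surrounding `gpt2_eval()` code):

```cpp
// Before the sync: the graph carried the thread count.
//   struct ggml_cgraph gf = {};
//   gf.n_threads = n_threads;
//   ggml_build_forward_expand(&gf, inpL);
//   ggml_graph_compute(ctx0, &gf);

// After the sync: pass n_threads to the compute helper instead.
struct ggml_cgraph gf = {};
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
```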


@ -37,5 +37,5 @@ wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.c
## TTS
For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use `espeak`, but you can use whatever you wish.
You can use any TTS engine you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use macOS's `say`, `espeak`, or the Windows SpeechSynthesizer, but you can use whatever you wish.


@ -1,23 +1,20 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
from elevenlabs import generate, play, save
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
# Save the TTS to a file
audio.save("audio")
save(audio, "audio.mp3")


@ -379,6 +379,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
// TODO: sync latest version from ggml repo
bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
@ -420,7 +421,6 @@ bool gpt2_eval(
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -442,7 +442,7 @@ bool gpt2_eval(
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL);
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
@ -589,7 +589,7 @@ bool gpt2_eval(
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
@ -644,7 +644,7 @@ bool gpt2_eval(
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL);
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
@ -664,8 +664,8 @@ bool gpt2_eval(
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);

examples/talk/speak.sh → examples/talk/speak Executable file → Normal file

@ -13,8 +13,11 @@
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk/eleven-labs.py
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with a limited number of characters. To increase this limit, register at https://beta.elevenlabs.io to get an API key and paste it after 'ELEVEN_API_KEY='
# Keep the line commented to use the free version without an API key
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2"

examples/talk/speak.bat Normal file

@ -0,0 +1 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

examples/talk/speak.ps1 Normal file

@ -0,0 +1,12 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
# voice options are David or Zira
[Parameter(Mandatory=$true)][string]$voice,
[Parameter(Mandatory=$true)][string]$text
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$speak.SelectVoice("Microsoft $voice Desktop");
$speak.Rate="0";
$speak.Speak($text);


@ -36,7 +36,7 @@ struct whisper_params {
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak.sh";
std::string speak = "./examples/talk/speak";
std::string fname_out;
};


@ -18,6 +18,9 @@ android {
vectorDrawables {
useSupportLibrary true
}
ndk {
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
}
}
buildTypes {
@ -42,8 +45,8 @@ android {
}
ndkVersion "25.1.8937393"
externalNativeBuild {
ndkBuild {
path 'src/main/jni/whisper/Android.mk'
cmake {
path = file("src/main/jni/whisper/CMakeLists.txt")
}
}
packagingOptions {


@ -10,12 +10,16 @@ fun decodeWaveFile(file: File): FloatArray {
file.inputStream().use { it.copyTo(baos) }
val buffer = ByteBuffer.wrap(baos.toByteArray())
buffer.order(ByteOrder.LITTLE_ENDIAN)
val channel = buffer.getShort(22).toInt()
buffer.position(44)
val shortBuffer = buffer.asShortBuffer()
val shortArray = ShortArray(shortBuffer.limit())
shortBuffer.get(shortArray)
return FloatArray(shortArray.size) { index ->
(shortArray[index] / 32767.0f).coerceIn(-1f..1f)
return FloatArray(shortArray.size / channel) { index ->
when (channel) {
1 -> (shortArray[index] / 32767.0f).coerceIn(-1f..1f)
else -> ((shortArray[2*index] + shortArray[2*index + 1])/ 32767.0f / 2.0f).coerceIn(-1f..1f)
}
}
}
@ -73,4 +77,4 @@ private fun headerBytes(totalLength: Int): ByteArray {
it.get(bytes)
return bytes
}
}
}
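
The Kotlin change above adds stereo support by averaging the two interleaved channels into mono. An equivalent C++ sketch of the same downmix (illustrative only; the function name is hypothetical):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch: interleaved 16-bit PCM -> mono float in [-1, 1].
// channels == 1 passes samples through; channels == 2 averages L and R,
// matching decodeWaveFile() above.
std::vector<float> pcm16ToMonoFloat(const std::vector<int16_t> & pcm, int channels) {
    std::vector<float> out(pcm.size() / channels);
    for (size_t i = 0; i < out.size(); ++i) {
        const float v = (channels == 1)
            ? pcm[i] / 32767.0f
            : (pcm[2*i] + pcm[2*i + 1]) / 32767.0f / 2.0f;
        out[i] = std::max(-1.0f, std::min(1.0f, v));
    }
    return out;
}
```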


@ -1,26 +0,0 @@
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper
include $(LOCAL_PATH)/Whisper.mk
include $(BUILD_SHARED_LIBRARY)
ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_vfpv4
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -mfpu=neon-vfpv4
include $(BUILD_SHARED_LIBRARY)
endif
ifeq ($(TARGET_ARCH_ABI),arm64-v8a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_v8fp16_va
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -march=armv8.2-a+fp16
include $(BUILD_SHARED_LIBRARY)
endif


@ -1 +0,0 @@
APP_STL := c++_static


@ -0,0 +1,53 @@
cmake_minimum_required(VERSION 3.10)
project(whisper.cpp)
set(CMAKE_CXX_STANDARD 11)
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
set(
SOURCE_FILES
${WHISPER_LIB_DIR}/ggml.c
${WHISPER_LIB_DIR}/whisper.cpp
${CMAKE_SOURCE_DIR}/jni.c
)
find_library(LOG_LIB log)
function(build_library target_name)
add_library(
${target_name}
SHARED
${SOURCE_FILES}
)
target_link_libraries(${target_name} ${LOG_LIB} android)
if (${target_name} STREQUAL "whisper_v8fp16_va")
target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16)
elseif (${target_name} STREQUAL "whisper_vfpv4")
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
endif ()
if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_compile_options(${target_name} PRIVATE -O3)
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
target_link_options(${target_name} PRIVATE -flto)
endif ()
endfunction()
build_library("whisper") # Default target
if (${ANDROID_ABI} STREQUAL "arm64-v8a")
build_library("whisper_v8fp16_va")
elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a")
build_library("whisper_vfpv4")
endif ()
include_directories(${WHISPER_LIB_DIR})


@ -1,18 +0,0 @@
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
LOCAL_LDLIBS := -landroid -llog
# Make the final output library smaller by only keeping the symbols referenced from the app.
ifneq ($(APP_OPTIM),debug)
LOCAL_CFLAGS += -O3
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
LOCAL_LDFLAGS += -Wl,--gc-sections
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
LOCAL_LDFLAGS += -flto
endif
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
LOCAL_CPPFLAGS += -std=c++11
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
$(WHISPER_LIB_DIR)/whisper.cpp \
$(LOCAL_PATH)/jni.c


@ -32,7 +32,7 @@ model="base.en"
# export the path to the whisper.cpp repo in the WHISPER_CPP_HOME env variable
# https://github.com/ggerganov/whisper.cpp
cd ${WHISPER_CPP_HOME}
cd "${WHISPER_CPP_HOME}"
if [ ! -f ./stream ] ; then
echo "whisper.nvim: the 'stream' executable was not found! WHISPER_CPP_HOME=${WHISPER_CPP_HOME}" > /tmp/whisper.nvim


@ -14,15 +14,24 @@ https://user-images.githubusercontent.com/1991296/204126266-ce4177c6-6eca-4bd9-b
```java
git clone https://github.com/ggerganov/whisper.cpp
open whisper.cpp/examples/whisper.objc/whisper.objc.xcodeproj/
// If you don't want to convert a Core ML model, you can skip this step by creating a dummy model
mkdir models/ggml-base.en-encoder.mlmodelc
```
Make sure to build the project in `Release`:
<img width="947" alt="image" src="https://user-images.githubusercontent.com/1991296/197382607-9e1e6d1b-79fa-496f-9d16-b71dc1535701.png">
Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag in Build Phases.
Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag for `ggml.c` in Build Phases.
This can significantly improve the performance of the transcription:
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
<img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
Then follow the [`Core ML support` section of the readme](../../README.md#core-ml-support) to convert the model.
This project also adds `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to the app project is not ideal in the real world (it applies to all C/C++ files); consider splitting the xcodeproj into a workspace in your own project.


@ -14,9 +14,13 @@
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; };
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
@ -37,6 +41,13 @@
18627C9529052C5800BD2A04 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
18627C9729052C6600BD2A04 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "ggml-base.en.bin"; path = "../../../models/ggml-base.en.bin"; sourceTree = "<group>"; };
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
7FE342482A0C3FA20015A058 /* whisper-decoder-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-decoder-impl.h"; sourceTree = "<group>"; };
7FE342492A0C3FA20015A058 /* whisper-encoder-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder-impl.h"; sourceTree = "<group>"; };
7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-decoder-impl.m"; sourceTree = "<group>"; };
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = "ggml-base.en-encoder.mlmodelc"; path = "../../../models/ggml-base.en-encoder.mlmodelc"; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -69,6 +80,8 @@
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
isa = PBXGroup;
children = (
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
7FE342442A0C3FA20015A058 /* coreml */,
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
18627C9729052C6600BD2A04 /* ggml.h */,
18627C9529052C5800BD2A04 /* ggml.c */,
@ -89,6 +102,20 @@
path = whisper.objc;
sourceTree = "<group>";
};
7FE342442A0C3FA20015A058 /* coreml */ = {
isa = PBXGroup;
children = (
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */,
7FE342462A0C3FA20015A058 /* whisper-encoder.h */,
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */,
7FE342482A0C3FA20015A058 /* whisper-decoder-impl.h */,
7FE342492A0C3FA20015A058 /* whisper-encoder-impl.h */,
7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */,
);
name = coreml;
path = ../../../coreml;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
@ -147,6 +174,7 @@
buildActionMask = 2147483647;
files = (
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */,
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */,
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */,
18627C8429052BDF00BD2A04 /* Main.storyboard in Resources */,
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */,
@ -161,11 +189,14 @@
buildActionMask = 2147483647;
files = (
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */,
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */,
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */,
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
18627C8C29052BE000BD2A04 /* main.m in Sources */,
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};


@ -10,36 +10,48 @@ fi
qtype0="q5_0"
qtype1="q5_1"
upload="$1"
declare -a filedex
cd `dirname $0`
cd ../
./quantize ./models/ggml-tiny.en.bin ./models/ggml-tiny.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-tiny.bin ./models/ggml-tiny-${qtype1}.bin ${qtype1}
# Let's loop across all the objects in the 'models' dir:
for i in ./models/*; do
# Check to see if it's a file or directory
if [ -d "$i" ]; then
# It's a directory! We should make sure it's not empty first:
if [ "$(ls -A $i)" ]; then
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
for f in "$i"/*.bin; do
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
if [ "$newfile" != "q5" ]; then
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
fi
done
fi
else
# It's a file! Let's make sure it's the right type:
if [ "${i##*.}" == "bin" ]; then
# And we probably want to skip the testing files
if [ "${i:9:8}" != "for-test" ]; then
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
fi
fi
fi
done
./quantize ./models/ggml-base.en.bin ./models/ggml-base.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-base.bin ./models/ggml-base-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-small.en.bin ./models/ggml-small.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-small.bin ./models/ggml-small-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-medium.en.bin ./models/ggml-medium.en-${qtype0}.bin ${qtype0}
./quantize ./models/ggml-medium.bin ./models/ggml-medium-${qtype0}.bin ${qtype0}
./quantize ./models/ggml-large.bin ./models/ggml-large-${qtype0}.bin ${qtype0}
if [ "$upload" == "1" ]; then
scp ./models/ggml-tiny.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-tiny.en-${qtype1}.bin
scp ./models/ggml-tiny-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-tiny-${qtype1}.bin
scp ./models/ggml-base.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-base.en-${qtype1}.bin
scp ./models/ggml-base-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-base-${qtype1}.bin
scp ./models/ggml-small.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-small.en-${qtype1}.bin
scp ./models/ggml-small-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-small-${qtype1}.bin
scp ./models/ggml-medium.en-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-medium.en-${qtype0}.bin
scp ./models/ggml-medium-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-medium-${qtype0}.bin
scp ./models/ggml-large-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-large-${qtype0}.bin
for i in ${!filedex[@]}; do
if [ "${filedex[$i]:9:8}" != "for-test" ]; then
scp ${filedex[$i]} root@linode0:/mnt/Data/ggml/ggml-model-${filedex[$i]:9}
fi
done
fi


@ -4,9 +4,18 @@ cp -rpv ../ggml/src/ggml.c ./ggml.c
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-opencl.c ./ggml-opencl.c
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/examples/common.h ./examples/common.h
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
cp -rpv ../ggml/examples/whisper/main.cpp ./examples/main/main.cpp
cp -rpv ../ggml/examples/whisper/quantize.cpp ./examples/quantize/quantize.cpp

ggml-alloc.c Normal file

@ -0,0 +1,594 @@
#include "ggml-alloc.h"
#include "ggml.h"
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
//#define GGML_ALLOCATOR_DEBUG
//#define AT_PRINTF printf
#define AT_PRINTF(...) ((void)0)
struct hash_node {
struct ggml_tensor * t;
int n_children;
int n_views;
};
static size_t hash(void * p) {
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
}
static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
size_t h = hash(t);
// linear probing
size_t i = h;
while (hash_table[i].t != NULL) {
if (hash_table[i].t == t) {
return &hash_table[i];
}
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
if (i == h) {
// hash table is full
GGML_ASSERT(false);
}
}
hash_table[i].t = t;
return &hash_table[i];
}
// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
assert(alignment && !(alignment & (alignment - 1))); // power of 2
size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
return offset + align;
}
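
A quick worked example of the alignment arithmetic (illustration, not in the source):

```cpp
// aligned_offset(buffer = (void *)0x1003, offset = 0, alignment = 16):
//   (uintptr_t)buffer + offset = 0x1003
//   0x1003 % 16                = 3
//   align = (16 - 3) % 16      = 13
//   result = 0 + 13            = 13   // buffer + 13 == 0x1010, 16-byte aligned
```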
struct free_block {
void * addr;
size_t size;
};
#define MAX_FREE_BLOCKS 128
struct ggml_allocr {
void * data;
size_t size;
size_t alignment;
int n_free_blocks;
struct free_block free_blocks[MAX_FREE_BLOCKS];
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size;
bool measure;
int parse_seq[GGML_MAX_CONCUR];
int parse_seq_len;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
#endif
};
#ifdef GGML_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
return;
}
}
GGML_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
alloc->allocated_tensors[i] = NULL;
return;
}
}
printf("tried to free tensor %s not found\n", tensor->name);
GGML_ASSERT(!"tensor not found");
}
#endif
static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
return ggml_nbytes(tensor);
UNUSED(alloc);
}
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
size_t max_avail = 0;
// find the best fitting free block besides the last block
int best_fit_block = -1;
size_t best_fit_size = SIZE_MAX;
for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
struct free_block * block = &alloc->free_blocks[i];
max_avail = MAX(max_avail, block->size);
if (block->size >= size && block->size <= best_fit_size) {
best_fit_block = i;
best_fit_size = block->size;
}
}
AT_PRINTF("block %d\n", best_fit_block);
if (best_fit_block == -1) {
// the last block is our last resort
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
if (block->size >= size) {
best_fit_block = alloc->n_free_blocks - 1;
max_avail = MAX(max_avail, block->size);
} else {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
return;
}
}
struct free_block * block = &alloc->free_blocks[best_fit_block];
void * addr = block->addr;
block->addr = (char*)block->addr + size;
block->size -= size;
if (block->size == 0) {
// remove block if empty
alloc->n_free_blocks--;
for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
tensor->data = addr;
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->data + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i]) {
printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
}
}
printf("\n");
}
#endif
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
void * ptr = tensor->data;
if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
return;
}
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
#endif
// see if we can merge with an existing block
for (int i = 0; i < alloc->n_free_blocks; i++) {
struct free_block * block = &alloc->free_blocks[i];
// check if ptr is at the end of the block
if ((char*)block->addr + block->size == ptr) {
block->size += size;
// check if we can merge with the next block
if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
block->size += alloc->free_blocks[i+1].size;
alloc->n_free_blocks--;
for (int j = i+1; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
return;
}
// check if ptr is at the beginning of the block
if ((char*)ptr + size == block->addr) {
block->addr = ptr;
block->size += size;
// check if we can merge with the previous block
if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
alloc->free_blocks[i-1].size += block->size;
alloc->n_free_blocks--;
for (int j = i; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
return;
}
}
// otherwise, add a new block
GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
// insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
int insert_pos = 0;
while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
insert_pos++;
}
// shift all blocks from insert_pos onward to make room for the new block
for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
alloc->free_blocks[i] = alloc->free_blocks[i-1];
}
// insert the new block
alloc->free_blocks[insert_pos].addr = ptr;
alloc->free_blocks[insert_pos].size = size;
alloc->n_free_blocks++;
}
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
for (int i = 0; i < n; i++) {
alloc->parse_seq[i] = list[i];
}
alloc->parse_seq_len = n;
}
void ggml_allocr_reset(struct ggml_allocr * alloc) {
alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
alloc->free_blocks[0].size = alloc->size - align_offset;
}
struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
*alloc = (struct ggml_allocr){
/*.data = */ data,
/*.size = */ size,
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
ggml_allocr_reset(alloc);
return alloc;
}
// address and size of the buffer when measuring
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
*alloc = (struct ggml_allocr){
/*.data = */ MEASURE_BASE_ADDR,
/*.size = */ MEASURE_MAX_SIZE,
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ true,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
ggml_allocr_reset(alloc);
return alloc;
}
void ggml_allocr_free(struct ggml_allocr * alloc) {
free(alloc);
}
bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
return alloc->measure;
}
//////////// compute graph allocator
static bool ggml_is_view(struct ggml_tensor * t) {
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
switch (t->op) {
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
return t->src[0];
case GGML_OP_CPY:
return t->src[1];
default:
return NULL;
}
}
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
struct ggml_tensor * parent = t;
do {
parent = get_view_parent(parent);
} while (ggml_is_view(parent));
return parent;
}
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_SET:
case GGML_OP_SOFT_MAX:
case GGML_OP_CONT:
case GGML_OP_ADD_REL_POS:
return true;
default:
return false;
}
}
static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
struct hash_node * ht = alloc->hash_table;
if (node->data == NULL) {
if (ggml_is_view(node)) {
size_t offset;
switch(node->op) {
case GGML_OP_VIEW:
memcpy(&offset, node->op_params, sizeof(size_t));
node->data = (char *) node->src[0]->data + offset;
break;
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
node->data = node->src[0]->data;
break;
case GGML_OP_CPY:
node->data = node->src[1]->data;
break;
default:
GGML_ASSERT(!"unknown view op");
break;
}
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
struct ggml_tensor * parent = node->src[i];
if (parent == NULL) {
break;
}
// if the node's data is external, then we cannot re-use it
if ((char *) parent->data < (char *) alloc->data ||
(char *) parent->data >= ((char *) alloc->data + alloc->size)) {
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
continue;
}
struct hash_node * p_hn = hash_get(ht, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
struct hash_node * view_src_hn = hash_get(ht, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
// the parent's data that it will need later (same layout requirement). the problem is that then
// we cannot free the tensor because the original address of the allocation is lost.
// adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
// for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->data = parent->data;
return;
}
}
else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->data = parent->data;
return;
}
}
}
}
ggml_allocr_alloc(alloc, node);
}
}
}
static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_allocr * alloc,
struct ggml_cgraph ** graphs, int n_graphs,
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
// reset hash table
struct hash_node * ht = alloc->hash_table;
memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
// count number of children and views
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = get_view_source(node);
hash_get(ht, view_src)->n_views += 1;
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
hash_get(ht, parent)->n_children += 1;
}
}
}
// allocate tensors
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
// graph inputs are allocated first to ensure that they are not overwritten by each other
if (inputs != NULL && inputs[g] != NULL) {
for (int i = 0; inputs[g][i] != NULL; i++) {
struct ggml_tensor * input = inputs[g][i];
AT_PRINTF("input: %s\n", input->name);
allocate_node(alloc, input);
}
}
// if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
int last_barrier_pos = 0;
int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
for (int ind = 0; ind < n_nodes; ind++) {
// allocate a node if there is no parse_seq or this is not a barrier
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs)
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
allocate_node(alloc, parent);
}
// allocate node
allocate_node(alloc, node);
AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
AT_PRINTF("%s", parent->name);
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
AT_PRINTF(", ");
}
}
AT_PRINTF("\n");
}
// update parents
// update immediately if there is no parse_seq
// update only at barriers if there is parse_seq
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
int update_end = alloc->parse_seq_len ? ind : ind + 1;
for (int i = update_start; i < update_end; i++) {
int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
struct ggml_tensor * node = gf->nodes[node_i];
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
struct hash_node * p_hn = hash_get(ht, parent);
p_hn->n_children -= 1;
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_allocator_free_tensor(alloc, view_src);
}
}
else {
if (parent->data != node->data) {
ggml_allocator_free_tensor(alloc, parent);
}
}
}
}
}
AT_PRINTF("\n");
if (alloc->parse_seq_len) {
last_barrier_pos = ind + 1;
}
}
}
// free graph outputs here that wouldn't be freed otherwise because they have no children
if (outputs != NULL && outputs[g] != NULL) {
for (int i = 0; outputs[g][i] != NULL; i++) {
struct ggml_tensor * output = outputs[g][i];
AT_PRINTF("output: %s\n", output->name);
ggml_allocator_free_tensor(alloc, output);
}
}
}
return alloc->max_size;
}
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}

ggml-alloc.h Normal file

@ -0,0 +1,26 @@
#pragma once
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph is optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large.


@ -1,18 +1,45 @@
#pragma once
#include "ggml.h"
#ifdef GGML_USE_HIPBLAS
#define GGML_CUDA_NAME "ROCm"
#define GGML_CUBLAS_NAME "hipBLAS"
#else
#define GGML_CUDA_NAME "CUDA"
#define GGML_CUBLAS_NAME "cuBLAS"
#endif
#ifdef __cplusplus
extern "C" {
#endif
void ggml_init_cublas(void);
#define GGML_CUDA_MAX_DEVICES 16
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
GGML_API void ggml_init_cublas(void);
GGML_API void * ggml_cuda_host_malloc(size_t size);
GGML_API void ggml_cuda_host_free(void * ptr);
// TODO: export these with GGML_API
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
GGML_API void ggml_cuda_set_main_device(int main_device);
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
GGML_API void ggml_cuda_free_scratch(void);
GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
#ifdef __cplusplus
}

ggml-metal.h Normal file

@ -0,0 +1,85 @@
// An interface for computing a ggml_cgraph with Metal
//
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
//
// How does it work?
//
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
//
// You only need to make sure that all memory buffers that you used during the graph creation
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
// used during the graph evaluation to determine the arguments of the compute kernels.
//
// Synchronization between device and host memory (for example for input and output tensors)
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
//
#pragma once
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor;
struct ggml_cgraph;
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_metal_context;
// number of command buffers to use
struct ggml_metal_context * ggml_metal_init(int n_cb);
void ggml_metal_free(struct ggml_metal_context * ctx);
void * ggml_metal_host_malloc(size_t n);
void ggml_metal_host_free (void * data);
// set the number of command buffers to use
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
// creates a mapping between a host memory buffer and a device memory buffer
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels
// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
// - max_size specifies the maximum size of a tensor and is used to create shared views such
// that it is guaranteed that the tensor will fit in at least one of the views
//
bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx,
const char * name,
void * data,
size_t size,
size_t max_size);
// set data from host memory into the device
void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// try to find operations that can be run concurrently in the graph
// you should run it again if the topology of your graph changes
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
// if the graph has been optimized for concurrent dispatch, return the length of the concur_list
int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// output the concur_list for ggml_alloc
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
#ifdef __cplusplus
}
#endif

Some files were not shown because too many files have changed in this diff.