Compare commits

..

79 Commits

Author SHA1 Message Date
00ea21668b whisper : account speed_up flag for short audio (close #405) 2023-01-15 12:42:15 +02:00
0b85e8c401 Update README.md 2023-01-15 11:36:20 +02:00
fafd78945d bench.wasm : print system info 2023-01-15 11:34:03 +02:00
8de452c18b Improve decoding (#291)
* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ration_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders
2023-01-15 11:29:57 +02:00
a6dbd9188b stream : fix a bug that inserted a lot of empty audio at the start
The quality was terrible due to this
2023-01-14 19:20:47 +02:00
4ef3398e8f ggml : remove obsolete zeroing + comment fixes (#390) 2023-01-08 20:21:03 +02:00
5e9f33596f readme : clarify main and stream usage (#391)
Give an example of ./main that uses a sample file that's already there, and make the stream example clarify you need `make stream`
2023-01-08 20:18:41 +02:00
8d7b29cedd ggml : correct behaviour of ggml_vec_sum_f32 (#390) 2023-01-08 20:06:09 +02:00
08dc705a69 whisper : fix sample_to_timestamp calculation with 64 bit precision to avoid overflow (#388)
* Do calculation with 64 bit precision to avoid overflow

* Update whisper.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-01-08 15:08:45 +02:00
1512545149 whisper : add loader class to allow loading from buffer and others (#353)
* whisper : add loader to allow loading from other than file

* whisper : rename whisper_init to whisper_init_from_file

* whisper : add whisper_init_from_buffer

* android : Delete local.properties

* android : load models directly from assets

* whisper : adding <stddef.h> needed for size_t + code style

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-01-08 13:03:33 +02:00
52a3e0c92a ggml : improve vec_dot_f16 unrolling in flash_attn_f16 2023-01-08 11:41:18 +02:00
d1ea1220ff command : clean-up / refactoring / formatting (#383) 2023-01-07 21:43:24 +02:00
9c4a1522f6 command : always-prompt mode (#383) 2023-01-07 21:41:11 +02:00
f078a6f20e go : adding features to the go-whisper example, go ci, etc (#384)
* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private
2023-01-07 21:21:43 +02:00
f30b5d322c ggml : fix bug in new soft max computation 2023-01-07 21:00:07 +02:00
44efbf7ff1 cmake : add -Wno-unused-function + update whisper.js 2023-01-07 20:18:34 +02:00
d347a59a5f ggml : when using BLAS start only 1 CPU thread 2023-01-07 19:48:56 +02:00
6394c906af ggml : fix running tasks with variable number of threads 2023-01-07 19:20:18 +02:00
74ffa14e1d ggml : unroll ggml_vec_dot_f16 in ggml_compute_forward_flash_attn_f16 2023-01-07 19:19:40 +02:00
65fdcbbbbb whisper : revert accidental MB change 2023-01-07 16:18:21 +02:00
d61d55cd4b ggml : speed-up soft max via Accelerate + unroll 2023-01-07 16:16:42 +02:00
d51fc3ee0a ggml : use vDSP_sve and vDSP_maxv from Accelerate 2023-01-07 16:10:16 +02:00
f82a7dd019 ggml : make gcc happy (minor) 2023-01-07 09:34:39 +02:00
87dd4a3081 talk.wasm : bump memory usage + update whisper.js 2023-01-06 21:13:44 +02:00
41e05c6b1b cmake : support AVX2 in Windows better (#381) 2023-01-06 19:36:33 +02:00
fa379cb22a Revert "tmp"
This reverts commit 1652965529.
2023-01-06 19:33:09 +02:00
322f4e6c4e go : bindings updated so they can be used in third party packages. (#379)
* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin
2023-01-06 19:32:28 +02:00
1652965529 tmp 2023-01-06 19:32:12 +02:00
6042c7a3be cmake : change min required version to 3.0 (#351)
We increase the min version only when want to use particular
functionality that is available in the newer version
2023-01-06 19:25:28 +02:00
6b351bb669 command : add "guided-mode" video demo in the README.md 2023-01-06 18:59:26 +02:00
a62170c656 ggml : add SSE3 and fp16 conversion lookup table (#368)
* Improves WASM performance:
  On MacBook M1 Pro, I observe 25% faster using Firefox and 35% faster using Chrome

* Add support for SSE3 SIMD

* Add SSE3 to system information

* Add Imath support for fp16-fp32 conversions

* Add Imath to system information

* Wrap Imath calls to avoid static function warnings

* Drop Imath; Add lookup table for f16 -> f32 conversions

* Remove TODO comments

* Update SSE3 to new macro arguments

* Correct updated macro definitions

* Prefer static inline where possible

* ggml : static inlines + add public f16 <-> f32 conversions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-01-06 18:45:59 +02:00
1944e7c33e whisper : document POWER VSX support 2023-01-05 23:53:00 +02:00
49a8dd6732 ggml : reorganize POWER9 ppc64le SIMD code 2023-01-05 23:53:00 +02:00
8c7f642286 ggml : change f16 load and store macro arguments 2023-01-05 23:53:00 +02:00
ad2a4ffa03 whisper : do not use F16 tensors when in F32 mode (#369) 2023-01-05 22:56:25 +02:00
b3c865083e ci : add emscripten build 2023-01-05 22:10:20 +02:00
a0d4f8e65c main : make whisper_print_segment_callback() more readable (close #371) 2023-01-05 21:45:05 +02:00
4a214d2f07 cmake : add CMAKE_RUNTIME_OUTPUT_DIRECTORY
Currently needed by the wasm examples
2023-01-05 21:40:59 +02:00
0a0cfa7985 ggml : add void to argument-less functions 2023-01-05 21:40:38 +02:00
196d738974 minor : close #370 + Makefile build info print change 2023-01-05 21:35:45 +02:00
84c6b42e65 cmake : update to 3.19 (#351)
- update from 3.0 (from 2014) to 3.19 (from 2020)
- move some global setting onto the targets (through a cmake include)
2023-01-05 21:22:48 +02:00
dd6d582977 whisper : use ranged-based for loops for readability 2023-01-05 21:20:44 +02:00
d51c5eb906 ggml : define MIN / MAX only if not defined (minor) 2023-01-05 21:16:52 +02:00
0be6a1afd9 make : print build information 2023-01-02 13:35:26 +02:00
a466c3404d stream : fix data race on bool + avoid division-by-zero 2023-01-02 10:20:50 +02:00
d629c034a4 models : fix HF model URL (close #356) 2023-01-02 09:54:43 +02:00
f00509d57c command : refactor to split command list & general transcription modes (#331)
This makes it easier to understand if you're looking for only one of the capabilities.
2022-12-31 14:08:57 +02:00
424c410c42 ggml : improve f16 acceleration for POWER9 ppc64le 2022-12-31 10:02:19 +02:00
d97e6005e9 whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
closes #344
2022-12-31 09:57:19 +02:00
3467230a77 models : fix typo in convert-h5-to-ggml.py
signficant -> significant
2022-12-31 09:49:01 +02:00
a091581eb3 cmake : add runtime destination install (#345)
needed for mingw32 build to successfully install the dlls in the correct location
2022-12-31 09:48:00 +02:00
68daf6e487 whisper : avoid some memory allocations 2022-12-30 13:43:48 +02:00
a593b932e4 main : add -ocsv, aka --output-csv to output a CSV file
Adds -ocsv, aka --output-csv feature to examples/main, which outputs a CSV file containing lines formatted as follows <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, "<transcript-line-including-commas>".
2022-12-29 14:04:00 +02:00
9a8ad3db69 make : add i686 arch (close #329) 2022-12-29 13:58:55 +02:00
4e0b2069e7 ggml : barrier refactor + static functions 2022-12-28 19:00:53 +02:00
ac521a566e ggml : simplify the SIMD code (#324)
* ggml : simplify the SIMD code

* ggml : generic reduce for all register sizes + comments
2022-12-24 10:22:28 +02:00
331c0bbddc examples : fix memory leak on failure to load gpt2 model (#323) 2022-12-23 20:19:07 +02:00
dc90efd504 examples : small code cleanups (#322)
- remove unnecessary initialization of string to ""
- use empty() instead of checking size()
- use emplace_back instead of push_back
- use nullptr instead of NULL
- remove unnecessary call to .data() on string
- use character overload of find_first_of() instead of passing a string
2022-12-23 20:18:51 +02:00
7282e2109e ggml : use vaddvq_f32 for slightly more efficient reduce 2022-12-23 13:48:19 +02:00
466ceebb78 ggml : add f16 acceleration for POWER9 ppc64le 2022-12-23 13:23:58 +02:00
77226aa89d models : fix support for spaces in path (close #315) 2022-12-23 11:11:38 +02:00
543bd5627e whisper : use emplace_back in place of push_back (#319)
This avoids potential construction of temporaries.
2022-12-23 11:07:19 +02:00
62fee9a9cc whisper : fix mem leak on failure to load model (#318) 2022-12-23 11:06:17 +02:00
493d94130d ggml : make consts static (#317)
These shouldn't be able to be referenced outside the compilation unit.
2022-12-23 11:05:27 +02:00
1480a5f1af Update README.md
Add SwiftUI example links
2022-12-23 11:02:46 +02:00
0f4227d9ee examples : add whisper.swiftui demo app (#308)
* Add SwiftUI demo project.

* Add -DGGML_USE_ACCELERATE
2022-12-23 10:56:18 +02:00
4c1fe0c813 Update README.md
Add bindings links / discussions
2022-12-22 18:22:58 +02:00
fa463313ad minor : small code cleanups (#302)
* Small code cleanups

- fix indentation
- remove extra semicolons
- remove extra break after returns in case statements
- remove unnecessary call to .data() on string
- use empty() instead of checking size()
- no need to check for nullptr before free
- remove unnecessary initialization of string to ""

* minor : switch case always break

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2022-12-22 17:06:19 +02:00
501a6b455c minor : flag "ARM FMA" -> "ARM_FMA" 2022-12-22 16:47:54 +02:00
91fc08c641 Build a vfpv4 library for armeabi-v7a and do runtime detection to select the right library 2022-12-22 16:47:54 +02:00
e1432dd91a Check for both __ARM_NEON and __ARM_FEATURE_FMA so that the project can be compiled for armv7a.
Android armeabi-v7a's NEON support doesn't support FMA unless configured with `-mfpu=neon-fp-armv8`, which would need runtime checks.
* Also removed ABI filter from Android project.
2022-12-22 16:47:54 +02:00
22193cbfe8 Bump NDK version 2022-12-22 16:47:54 +02:00
42c6730732 whisper : use nullptr (C++11) instead of NULL macro (#299) 2022-12-22 16:35:18 +02:00
76b6211f9b cmake : add headers to target (#298)
This will show the header files in IDEs.
2022-12-22 16:34:47 +02:00
86a277f78d go : run go mod tidy before building examples + fix permissions (#296)
* run `go mod tidy` before building examples

Running `make examples` after cloning the repository gives the following
error:

```
...
[100%] Built target whisper
gmake[3]: Leaving directory '/tmp/exp/whisper.cpp/bindings/go/build'
gmake[2]: Leaving directory '/tmp/exp/whisper.cpp/bindings/go/build'
gmake[1]: Leaving directory '/tmp/exp/whisper.cpp/bindings/go/build'
Build example go-model-download
Build example go-whisper
examples/go-whisper/process.go:11:2: missing go.sum entry for module providing package github.com/go-audio/wav (imported by github.com/ggerganov/whisper.cpp/bindings/go/examples/go-whisper); to add:
        go get github.com/ggerganov/whisper.cpp/bindings/go/examples/go-whisper
make: *** [Makefile:26: examples/go-whisper] Error 1
```

* remove executable bit from various files
2022-12-22 16:34:20 +02:00
231bebca7d bindings : initial import of golang bindings (#287)
* Initial import of golang bindings

* Updated makefile rules

* Updated bindings

* Makefile update to add in more tests
2022-12-20 08:54:33 +02:00
90564f85f9 Update README.md 2022-12-19 22:09:21 +02:00
99da1e5cc8 cmake : enable and fix -Wall -Wextra -Wpedantic C++ warnings 2022-12-19 20:45:08 +02:00
8e3f129b4d minor : resolves some of warnings when compiling with clang/clang++ (#294)
* Resolves some of warnings when compiling with clang/clang++

Mostly nit stuff that clang catches when compiling with -Wall -Wextra
-pedantic.

- Fix comparison between sign/unsigned integers.
- Passes a constant reference (const&) instead of copying each time.

* minor : normalize coding style

* minor : fix warning

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2022-12-19 20:19:01 +02:00
93 changed files with 6701 additions and 2301 deletions

17
.github/workflows/bindings.yml vendored Normal file
View File

@ -0,0 +1,17 @@
name: Bindings Tests
on:
push:
paths:
- bindings/go/**
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: '^1.19'
- uses: actions/checkout@v1
- run: |
cd bindings/go
make test

View File

@ -235,3 +235,33 @@ jobs:
with:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make

2
.gitignore vendored
View File

@ -8,6 +8,7 @@ build/
build-em/
build-debug/
build-release/
build-static/
build-sanitize-addr/
build-sanitize-thread/
@ -18,6 +19,7 @@ build-sanitize-thread/
/bench
sync.sh
libwhisper.a
libwhisper.so
compile_commands.json

View File

@ -2,14 +2,15 @@ cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.0.4)
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
include(cmake/GitVars.cmake)
include(cmake/BuildTypes.cmake)
include(GitVars)
include(BuildTypes)
# configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
@ -52,6 +53,7 @@ if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
@ -82,9 +84,6 @@ endif()
# dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework
@ -131,6 +130,13 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \
-Wstrict-prototypes \
-Wpointer-arith \
-Wno-unused-function \
")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \
-Wextra \
-Wpedantic \
-Wcast-qual \
")
else()
# todo : msvc
@ -151,6 +157,7 @@ else()
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -162,7 +169,10 @@ else()
if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
if(NOT WHISPER_NO_FMA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
endif()
endif()
@ -178,10 +188,14 @@ endif()
set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
whisper.h
whisper.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
@ -215,6 +229,7 @@ target_compile_definitions(${TARGET} PUBLIC
install(TARGETS ${TARGET}
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
)
#

View File

@ -10,6 +10,9 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin)
@ -53,10 +56,13 @@ endif
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),x86_64)
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_S),Darwin)
CFLAGS += -mfma -mf16c
CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
@ -81,6 +87,10 @@ ifeq ($(UNAME_M),x86_64)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
@ -105,6 +115,12 @@ endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifeq ($(UNAME_M),ppc64le)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M)))
CFLAGS += -mpower9-vector
endif
endif
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
@ -135,6 +151,21 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
endif
#
# Print build information
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
#

View File

@ -11,6 +11,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- Low memory usage (Flash Attention + Flash Forward)
- Zero memory allocations at runtime
@ -70,7 +71,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
make
# transcribe an audio file
./main -f input.wav
./main -f samples/jfk.wav
```
---
@ -88,27 +89,36 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
@ -211,17 +221,7 @@ make large
## Limitations
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
- No GPU support (yet)
## Another example
@ -306,6 +306,7 @@ The [stream](examples/stream) tool samples the audio every half a second and run
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```java
make stream
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
@ -447,12 +448,13 @@ or manually from here:
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
in [models](models).
## Bindings
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm)
- [X] Javascript: [bindings/javascript](bindings/javascript)
- [ ] Python: soon
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
## Examples
@ -467,6 +469,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |

2
bindings/go/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
build
models

21
bindings/go/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 David Thorpe
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

38
bindings/go/Makefile Normal file
View File

@ -0,0 +1,38 @@
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
all: clean whisper examples
whisper: mkdir
@echo Build whisper
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir:
@echo Mkdir ${BUILD_DIR}
@install -d ${BUILD_DIR}
@echo Mkdir ${MODELS_DIR}
@install -d ${MODELS_DIR}
modtidy:
@go mod tidy
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go clean

100
bindings/go/README.md Normal file
View File

@ -0,0 +1,100 @@
# Go bindings for Whisper
This package provides Go bindings for whisper.cpp. They have been tested on:
* Darwin (OS X) 12.6 on x64_64
* Debian Linux on arm64
* Fedora Linux on x86_64
The "low level" bindings are in the `bindings/go` directory and there is a more
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
is as follows:
```go
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
var modelpath string // Path to the model
var samples []float32 // Samples to process
// Load the model
model, err := whisper.New(modelpath)
if err != nil {
panic(err)
}
defer model.Close()
// Process samples
context, err := model.NewContext()
if err != nil {
panic(err)
}
if err := context.Process(samples, nil); err != nil {
return err
}
// Print out the results
for {
segment, err := context.NextSegment()
if err != nil {
break
}
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
}
}
```
## Building & Testing
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/go
make test
```
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
```bash
make examples
```
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
```bash
./build/go-model-download -out models
```
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

5
bindings/go/doc.go Normal file
View File

@ -0,0 +1,5 @@
/*
github.com/ggerganov/whisper.cpp/bindings/go
provides a speech-to-text service bindings for the Go programming language.
*/
package whisper

View File

@ -0,0 +1,30 @@
package main
import (
"context"
"os"
"os/signal"
)
// ContextForSignal returns a context object which is cancelled when a signal
// is received. It returns nil if no signal parameter is provided
func ContextForSignal(signals ...os.Signal) context.Context {
if len(signals) == 0 {
return nil
}
ch := make(chan os.Signal)
ctx, cancel := context.WithCancel(context.Background())
// Send message on channel when signal received
signal.Notify(ch, signals...)
// When any signal received, call cancel
go func() {
<-ch
cancel()
}()
// Return success
return ctx
}

View File

@ -0,0 +1,208 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"syscall"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
)
var (
// The output folder. When not set, use current working directory.
flagOut = flag.String("out", "", "Output folder")
// HTTP timeout parameter - will timeout if takes longer than this to download a model
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
// Quiet parameter - will not print progress if set
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
)
///////////////////////////////////////////////////////////////////////////////
// MAIN
func main() {
flag.Usage = func() {
name := filepath.Base(flag.CommandLine.Name())
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
flag.PrintDefaults()
}
flag.Parse()
// Get output path
out, err := GetOut()
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
// Create context which quits on SIGINT or SIGQUIT
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
// Progress filehandle
progress := os.Stdout
if *flagQuiet {
progress, err = os.Open(os.DevNull)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
defer progress.Close()
}
// Download models - exit on error or interrupt
for _, model := range GetModels() {
url, err := URLForModel(model)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
continue
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
continue
} else if err == context.Canceled {
os.Remove(path)
fmt.Fprintln(progress, "\nInterrupted")
break
} else if err == context.DeadlineExceeded {
os.Remove(path)
fmt.Fprintln(progress, "Timeout downloading model")
continue
} else {
os.Remove(path)
fmt.Fprintln(os.Stderr, "Error:", err)
break
}
}
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// GetOut returns the path to the output directory
func GetOut() (string, error) {
if *flagOut == "" {
return os.Getwd()
}
if info, err := os.Stat(*flagOut); err != nil {
return "", err
} else if !info.IsDir() {
return "", fmt.Errorf("not a directory: %s", info.Name())
} else {
return *flagOut, nil
}
}
// GetModels returns the list of models to download
func GetModels() []string {
if flag.NArg() == 0 {
return modelNames
} else {
return flag.Args()
}
}
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
} else {
url.Path = filepath.Join(url.Path, model)
}
return url.String(), nil
}
// Download downloads the model from the given URL to the given output directory
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
// Create HTTP client
client := http.Client{
Timeout: *flagTimeout,
}
// Initiate the download
req, err := http.NewRequest("GET", model, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s: %s", model, resp.Status)
}
// If output file exists and is the same size as the model, skip
path := filepath.Join(out, filepath.Base(model))
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
fmt.Fprintln(p, "Skipping", model, "as it already exists")
return "", nil
}
// Create file
w, err := os.Create(path)
if err != nil {
return "", err
}
defer w.Close()
// Report
fmt.Fprintln(p, "Downloading", model, "to", out)
// Progressively download the model
data := make([]byte, bufSize)
count, pct := int64(0), int64(0)
ticker := time.NewTicker(5 * time.Second)
for {
select {
case <-ctx.Done():
// Cancelled, return error
return path, ctx.Err()
case <-ticker.C:
pct = DownloadReport(p, pct, count, resp.ContentLength)
default:
// Read body
n, err := resp.Body.Read(data)
if err != nil {
DownloadReport(p, pct, count, resp.ContentLength)
return path, err
} else if m, err := w.Write(data[:n]); err != nil {
return path, err
} else {
count += int64(m)
}
}
}
}
// Report periodically reports the download progress when percentage changes
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
pct_ := count * 100 / total
if pct_ > pct {
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
}
return pct_
}

View File

@ -0,0 +1,22 @@
package main
import "fmt"
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
Reset = "\033[0m"
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
RGBSuffix = "m"
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Colorize text with RGB values, from 0 to 23
func Colorize(text string, v int) string {
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
// Grayscale colors are in the range 232-255
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
}

View File

@ -0,0 +1,156 @@
package main
import (
"flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type Flags struct {
*flag.FlagSet
}
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewFlags(name string, args []string) (*Flags, error) {
flags := &Flags{
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
}
// Register the command line arguments
registerFlags(flags)
// Parse command line
if err := flags.Parse(args); err != nil {
return nil, err
}
// Return success
return flags, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (flags *Flags) GetModel() string {
return flags.Lookup("model").Value.String()
}
func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}
func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}
func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}
func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}
func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}
func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
}

View File

@ -0,0 +1,43 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
if err == flag.ErrHelp {
os.Exit(0)
} else if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
} else if flags.GetModel() == "" {
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
os.Exit(1)
} else if flags.NArg() == 0 {
fmt.Fprintln(os.Stderr, "No input files specified")
os.Exit(1)
}
// Load model
model, err := whisper.New(flags.GetModel())
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
defer model.Close()
// Process files
for _, filename := range flags.Args() {
if err := Process(model, filename, flags); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
}
}

View File

@ -0,0 +1,127 @@
package main
import (
"fmt"
"io"
"os"
"time"
// Package imports
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32
// Create processing context
context, err := model.NewContext()
if err != nil {
return err
}
// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}
// Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
} else if dec.SampleRate != whisper.SampleRate {
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
} else if dec.NumChans != 1 {
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
} else {
data = buf.AsFloat32Buffer().Data
}
// Segment callback when -tokens is specified
var cb whisper.SegmentCallback
if flags.IsTokens() {
cb = func(segment whisper.Segment) {
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
}
fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
}
}
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil {
return err
}
// Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}
// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintln(w, n)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
}
}
// Output text to terminal
func Output(w io.Writer, context whisper.Context, colorize bool) error {
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
if colorize {
for _, token := range segment.Tokens {
if !context.IsText(token) {
continue
}
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
}
fmt.Fprint(w, "\n")
} else {
fmt.Fprintln(w, " ", segment.Text)
}
}
}
// Return srtTimestamp
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
}

16
bindings/go/go.mod Normal file
View File

@ -0,0 +1,16 @@
module github.com/ggerganov/whisper.cpp/bindings/go
go 1.19
require (
github.com/go-audio/wav v1.1.0
github.com/stretchr/testify v1.8.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

23
bindings/go/go.sum Normal file
View File

@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

156
bindings/go/params.go Normal file
View File

@ -0,0 +1,156 @@
package whisper
import (
"fmt"
)
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#include <whisper.h>
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (p *Params) SetTranslate(v bool) {
p.translate = toBool(v)
}
func (p *Params) SetNoContext(v bool) {
p.no_context = toBool(v)
}
func (p *Params) SetSingleSegment(v bool) {
p.single_segment = toBool(v)
}
func (p *Params) SetPrintSpecial(v bool) {
p.print_special = toBool(v)
}
func (p *Params) SetPrintProgress(v bool) {
p.print_progress = toBool(v)
}
func (p *Params) SetPrintRealtime(v bool) {
p.print_realtime = toBool(v)
}
func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}
func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
// Set language id
func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang))
if str == nil {
return ErrInvalidLanguage
} else {
p.language = str
}
return nil
}
// Get language id
func (p *Params) Language() int {
if p.language == nil {
return -1
}
return int(C.whisper_lang_id(p.language))
}
// Set number of threads to use
func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads)
}
// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms)
}
// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms)
}
// Set timestamp token probability threshold (~0.01)
func (p *Params) SetTokenThreshold(t float32) {
p.thold_pt = C.float(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (p *Params) SetTokenSumThreshold(t float32) {
p.thold_ptsum = C.float(t)
}
// Set max segment length in characters
func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}
// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func toBool(v bool) C.bool {
if v {
return C.bool(true)
}
return C.bool(false)
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (p *Params) String() string {
str := "<whisper.params"
str += fmt.Sprintf(" strategy=%v", p.strategy)
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
if p.language != nil {
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
}
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
if p.translate {
str += " translate"
}
if p.no_context {
str += " no_context"
}
if p.single_segment {
str += " single_segment"
}
if p.print_special {
str += " print_special"
}
if p.print_progress {
str += " print_progress"
}
if p.print_realtime {
str += " print_realtime"
}
if p.print_timestamps {
str += " print_timestamps"
}
if p.token_timestamps {
str += " token_timestamps"
}
if p.speed_up {
str += " speed_up"
}
return str + ">"
}

View File

@ -0,0 +1,28 @@
package whisper
import (
"errors"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
// SampleRate is the sample rate of the audio data.
const SampleRate = whisper.SampleRate
// SampleBits is the number of bytes per sample.
const SampleBits = whisper.SampleBits

View File

@ -0,0 +1,251 @@
package whisper
import (
"io"
"strings"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type context struct {
n int
model *model
params whisper.Params
}
// Make sure context adheres to the interface
var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func newContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
// Return success
return context, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Set the language to use for speech recognition.
func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
if !context.model.IsMultilingual() {
return ErrModelNotMultilingual
}
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
return err
}
// Return success
return nil
}
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Get language
func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
}
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
}
// Set time offset
func (context *context) SetOffset(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
context.params.SetTokenThreshold(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
context.params.SetTokenSumThreshold(t)
}
// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
// If the callback is defined then we force on single_segment mode
if cb != nil {
context.params.SetSingleSegment(true)
}
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
// Return success
return nil
}
// Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) {
if context.model.ctx == nil {
return Segment{}, ErrInternalAppError
}
if context.n >= context.model.ctx.Whisper_full_n_segments() {
return Segment{}, io.EOF
}
// Populate result
result := toSegment(context.model.ctx, context.n)
// Increment the cursor
context.n++
// Return success
return result, nil
}
// Test for text tokens
func (context *context) IsText(t Token) bool {
switch {
case context.IsBEG(t):
return false
case context.IsSOT(t):
return false
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
return false
case context.IsPREV(t):
return false
case context.IsSOLM(t):
return false
case context.IsNOT(t):
return false
default:
return true
}
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}
// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}
// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}
// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}
// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}
// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}
// Test for token associated with a specific language
func (context *context) IsLANG(t Token, lang string) bool {
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
} else {
return false
}
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func toSegment(ctx *whisper.Context, n int) Segment {
return Segment{
Num: n,
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
Tokens: toTokens(ctx, n),
}
}
func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ {
result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i),
}
}
return result
}

View File

@ -0,0 +1,55 @@
package whisper_test
import (
"os"
"testing"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "../../models/ggml-tiny.bin"
SamplePath = "../../samples/jfk.wav"
)
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Load model
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
assert.NoError(model.Close())
t.Log("languages=", model.Languages())
}
func Test_Whisper_001(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Load model
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
// Get context for decoding
ctx, err := model.NewContext()
assert.NoError(err)
assert.NotNil(ctx)
}

View File

@ -0,0 +1,4 @@
/*
This is the higher-level speech-to-text whisper.cpp API for go
*/
package whisper

View File

@ -0,0 +1,85 @@
package whisper
import (
"io"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
// SegmentCallback is the callback function for processing segments in real
// time. It is called during the Process function
type SegmentCallback func(Segment)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
io.Closer
// Return a new speech-to-text context.
NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported.
Languages() []string
}
// Context is the speach recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
}
// Segment is the text result of a speech recognition.
type Segment struct {
// Segment Number
Num int
// Time beginning and end timestamps for the segment.
Start, End time.Duration
// The text of the segment.
Text string
// The tokens of the segment.
Tokens []Token
}
// Token is a text or special token
type Token struct {
Id int
Text string
P float32
}

View File

@ -0,0 +1,100 @@
package whisper
import (
"fmt"
"os"
"runtime"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type model struct {
path string
ctx *whisper.Context
}
// Make sure model adheres to the interface
var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (Model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
model.path = path
}
// Return success
return model, nil
}
func (model *model) Close() error {
if model.ctx != nil {
model.ctx.Whisper_free()
}
// Release resources
model.ctx = nil
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (model *model) String() string {
str := "<whisper.model"
if model.ctx != nil {
str += fmt.Sprintf(" model=%q", model.path)
}
return str + ">"
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
str := whisper.Whisper_lang_str(i)
if model.ctx.Whisper_lang_id(str) >= 0 {
result = append(result, str)
}
}
return result
}
func (model *model) NewContext() (Context, error) {
if model.ctx == nil {
return nil, ErrInternalAppError
}
// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
params.SetPrintRealtime(false)
params.SetPrintTimestamps(false)
params.SetThreads(runtime.NumCPU())
// Return new context
return newContext(model, params)
}

BIN
bindings/go/samples/jfk.wav Normal file

Binary file not shown.

419
bindings/go/whisper.go Normal file
View File

@ -0,0 +1,419 @@
package whisper
import (
"errors"
"unsafe"
)
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#cgo LDFLAGS: -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h>
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callNewSegment(user_data, n_new);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) {
return callEncoderBegin(user_data);
}
return false;
}
// Get default parameters and set callbacks
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
struct whisper_full_params params = whisper_full_default_params(strategy);
params.new_segment_callback = whisper_new_segment_cb;
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
return params;
}
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// TYPES
type (
Context C.struct_whisper_context
Token C.whisper_token
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
)
///////////////////////////////////////////////////////////////////////////////
// GLOBALS
const (
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
)
const (
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
NumFFT = C.WHISPER_N_FFT
NumMEL = C.WHISPER_N_MEL
HopLength = C.WHISPER_HOP_LENGTH
ChunkSize = C.WHISPER_CHUNK_SIZE
)
var (
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
ErrConversionFailed = errors.New("whisper_convert failed")
ErrInvalidLanguage = errors.New("invalid language")
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil
}
}
// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))
}
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
func (ctx *Context) Whisper_encode(offset, threads int) error {
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// whisper_sample_best() returns the token with the highest probability
func (ctx *Context) Whisper_sample_best() TokenData {
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
}
// whisper_sample_timestamp() returns the most probable timestamp token
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
}
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
cText := C.CString(text)
defer C.free(unsafe.Pointer(cText))
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
return int(n), nil
} else {
return 0, ErrTokenizerFailed
}
}
// Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
// Largest language id (i.e. number of available languages - 1)
func Whisper_lang_max_id() int {
return int(C.whisper_lang_max_id())
}
// Return the short string of the specified language id (e.g. 2 -> "de"),
// returns empty string if not found
func Whisper_lang_str(id int) string {
return C.GoString(C.whisper_lang_str(C.int(id)))
}
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
probs := make([]float32, Whisper_lang_max_id()+1)
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
return nil, ErrAutoDetectFailed
} else {
return probs, nil
}
}
func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_vocab() int {
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
}
// The probabilities for the next token
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
//}
// Token Id -> String. Uses the vocabulary in the provided context
func (ctx *Context) Whisper_token_to_str(token Token) string {
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
}
// Special tokens
func (ctx *Context) Whisper_token_eot() Token {
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_sot() Token {
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_prev() Token {
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_solm() Token {
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_not() Token {
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_beg() Token {
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
}
// Task tokens
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
}
// Task tokens
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
}
// Performance information
func (ctx *Context) Whisper_print_timings() {
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
}
// Performance information
func (ctx *Context) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
}
// Print system information
func Whisper_print_system_info() string {
return C.GoString(C.whisper_print_system_info())
}
// Return default parameters for a strategy
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
// Get default parameters
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
}
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
///////////////////////////////////////////////////////////////////////////////
// CALLBACKS
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbNewSegment, unsafe.Pointer(ctx))
} else {
cbNewSegment[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
} else {
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
}
}
//export callNewSegment
func callNewSegment(user_data unsafe.Pointer, new C.int) {
if fn, ok := cbNewSegment[user_data]; ok {
fn(int(new))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
if fn() {
return C.bool(true)
} else {
return C.bool(false)
}
}
return true
}

113
bindings/go/whisper_test.go Normal file
View File

@ -0,0 +1,113 @@
package whisper_test
import (
"os"
"runtime"
"testing"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
wav "github.com/go-audio/wav"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "models/ggml-small.en.bin"
SamplePath = "samples/jfk.wav"
)
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
ctx.Whisper_free()
}
func Test_Whisper_001(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
assert.NoError(err)
// Run whisper
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens
num_segments := ctx.Whisper_full_n_segments()
assert.GreaterOrEqual(num_segments, 1)
for i := 0; i < num_segments; i++ {
str := ctx.Whisper_full_get_segment_text(i)
assert.NotEmpty(str)
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
t.Logf("[%6s->%-6s] %q", t0, t1, str)
}
}
func Test_Whisper_002(t *testing.T) {
assert := assert.New(t)
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
str := whisper.Whisper_lang_str(i)
assert.NotEmpty(str)
t.Log(str)
}
}
func Test_Whisper_003(t *testing.T) {
threads := runtime.NumCPU()
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
assert.NoError(err)
// Make the model
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
// Get MEL
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
// Get Languages
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
assert.NoError(err)
for i, p := range languages {
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
}
}

View File

@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init(path_model.c_str());
g_context = whisper_init_from_file(path_model.c_str());
if (g_context != nullptr) {
return true;
} else {

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,17 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -28,6 +28,11 @@ void bench_main(size_t index) {
return;
}
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
}
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return;
@ -52,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();

View File

@ -1,3 +1,6 @@
set(TARGET bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -33,7 +33,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
{
fprintf(stderr, "\n");

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -9,7 +9,19 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
@ -17,9 +29,8 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Building

View File

@ -11,6 +11,7 @@
#include <SDL.h>
#include <SDL_audio.h>
#include <sstream>
#include <cassert>
#include <cstdio>
#include <fstream>
@ -25,7 +26,7 @@
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 4000;
int32_t command_ms = 8000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
@ -41,8 +42,9 @@ struct whisper_params {
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out = "";
std::string commands = "";
std::string fname_out;
std::string commands;
std::string prompt;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -71,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -81,7 +84,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, "\n");
}
@ -387,7 +391,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
@ -510,6 +514,433 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
return allowed_commands;
}
std::vector<std::string> get_words(const std::string &txt) {
std::vector<std::string> words;
std::istringstream iss(txt);
std::string word;
while (iss >> word) {
words.push_back(word);
}
return words;
}
// returns true if no exit event was received
bool process_sdl_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}
// command-list mode
// guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
int max_len = 0;
std::vector<std::vector<whisper_token>> allowed_tokens;
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < (int) cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
std::vector<whisper_token> k_tokens;
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
bool is_running = true;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
// compute probs from logits via softmax
{
float max = -1e9;
for (int i = 0; i < (int) probs.size(); ++i) {
max = std::max(max, logits[i]);
}
float sum = 0.0f;
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] = expf(logits[i] - max);
sum += probs[i];
}
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
}
}
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
}
fprintf(stdout, "\n");
}
}
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
const float prob = probs_id[0].first;
const int index = probs_id[0].second;
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
}
audio.clear();
}
}
return 0;
}
// always-prompt mode
// transcribe the voice into text after valid prompt
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool ask_prompt = true;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
const std::string k_prompt = params.prompt;
const int k_prompt_length = get_words(k_prompt).size();
fprintf(stderr, "\n");
fprintf(stderr, "%s: always-prompt mode\n", __func__);
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0;
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto words = get_words(txt);
std::string prompt;
std::string command;
for (int i = 0; i < (int) words.size(); ++i) {
if (i < k_prompt_length) {
prompt += words[i] + " ";
} else {
command += words[i] + " ";
}
}
const float sim = similarity(prompt, k_prompt);
//debug
//fprintf(stdout, "command size: %i\n", command_length);
if ((sim > 0.7f) && (command.size() > 0)) {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
}
fprintf(stdout, "\n");
audio.clear();
}
}
}
return 0;
}
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands.";
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0;
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
}
return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
@ -525,7 +956,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
// print some info about the processing
{
@ -561,300 +992,14 @@ int main(int argc, char ** argv) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
int max_len = 0;
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
std::vector<std::string> allowed_commands;
std::vector<std::vector<whisper_token>> allowed_tokens;
std::string k_prompt = "";
std::vector<whisper_token> k_tokens;
if (params.commands != "") {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
int ret_val = 0;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
k_prompt = "Ok Whisper, start listening for commands.";
}
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (allowed_commands.empty()) {
// general-purpose mode
// freely transcribe the voice into text
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
int64_t t_ms = 0;
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
} else {
// command-list mode
// guide the transcription to match the most likely command from a provided list
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
const auto * probs = whisper_get_probs(ctx);
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
}
fprintf(stdout, "\n");
}
}
// best command
{
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - t_start).count());
fprintf(stdout, "\n");
}
const auto t_end = std::chrono::high_resolution_clock::now();
audio.clear();
}
}
ret_val = process_general_transcription(ctx, audio, params);
}
audio.pause();
@ -862,5 +1007,5 @@ int main(int argc, char ** argv) {
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
return ret_val;
}

View File

@ -1,3 +1,6 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -59,8 +59,12 @@ struct whisper_params {
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f;
float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
bool speed_up = false;
bool translate = false;
@ -69,13 +73,14 @@ struct whisper_params {
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string prompt = "";
std::string prompt;
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
@ -103,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -111,6 +120,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -118,7 +128,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -129,35 +139,40 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
@ -173,90 +188,81 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
@ -325,10 +331,35 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ') {
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
}
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -377,7 +408,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul = "\\ \\ ";
{
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
@ -401,8 +431,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul += "\\ ";
}
}
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "\u2019");
@ -461,7 +489,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
@ -471,7 +499,7 @@ int main(int argc, char ** argv) {
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (params.prompt.size() > 0) {
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
@ -508,14 +536,14 @@ int main(int argc, char ** argv) {
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
@ -531,7 +559,7 @@ int main(int argc, char ** argv) {
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
@ -550,11 +578,11 @@ int main(int argc, char ** argv) {
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
@ -565,7 +593,7 @@ int main(int argc, char ** argv) {
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (int i = 0; i < n; i++) {
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
@ -603,6 +631,8 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
@ -616,12 +646,18 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = -1;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s };
@ -637,7 +673,7 @@ int main(int argc, char ** argv) {
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
@ -677,6 +713,13 @@ int main(int argc, char ** argv) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_inp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -49,6 +49,9 @@ void stream_main(size_t index) {
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads);
@ -129,7 +132,7 @@ EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -8,6 +8,7 @@
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert>
#include <cstdio>
#include <string>
@ -51,7 +52,7 @@ struct whisper_params {
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out = "";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -90,7 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
@ -144,8 +145,8 @@ private:
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
@ -155,6 +156,8 @@ private:
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
@ -391,7 +394,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
@ -427,10 +430,10 @@ int main(int argc, char ** argv) {
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
params.no_timestamps = !use_vad;
params.no_context = use_vad;
params.max_tokens = 0;
@ -453,10 +456,10 @@ int main(int argc, char ** argv) {
exit(0);
}
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens;
@ -612,6 +615,9 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

View File

@ -9,6 +9,8 @@ add_executable(${TARGET}
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
@ -31,8 +33,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1600MB \
-s TOTAL_MEMORY=1600MB \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.6GB of RAM
- The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.

View File

@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -8,6 +8,9 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -40,7 +40,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
if (word.empty()) continue;
int i = 0;
int n = word.size();
@ -78,7 +78,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
const float * logits,
int top_k,
double top_p,
double temp,
double /*temp*/,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
@ -86,7 +86,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.push_back(std::make_pair(logits[i], i));
logits_id.emplace_back(logits[i], i);
}
// find the top K tokens
@ -268,7 +268,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
fin.read((char *) &word[0], len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
@ -327,7 +327,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{
struct ggml_init_params params;
params.mem_size = ctx_size;
params.mem_buffer = NULL;
params.mem_buffer = nullptr;
model.ctx = ggml_init(params);
if (!model.ctx) {
@ -448,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
if (model.tensors.find(name) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
@ -833,7 +833,7 @@ Me too.
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(NULL));
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
@ -841,6 +841,7 @@ struct gpt2_context * gpt2_init(const char * path_model) {
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
@ -884,9 +885,9 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
std::string result;
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (embd.size() > 0) {
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";

View File

@ -39,7 +39,7 @@ struct whisper_params {
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak.sh";
std::string fname_out = "";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -79,7 +79,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
@ -397,7 +397,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
@ -498,7 +498,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx_wsp = whisper_init(params.model_wsp.c_str());
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// gpt init
@ -541,7 +541,6 @@ int main(int argc, char ** argv) {
bool force_speak = false;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
@ -589,7 +588,7 @@ int main(int argc, char ** argv) {
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard = "";
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
@ -611,7 +610,7 @@ int main(int argc, char ** argv) {
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
@ -641,18 +640,18 @@ int main(int argc, char ** argv) {
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
// remove first 2 lines of base prompt
if (n_iter > 4) {
{
const size_t pos = prompt_base.find_first_of("\n");
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of("\n");
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}

View File

@ -1,5 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>

View File

@ -14,10 +14,6 @@ android {
versionCode 1
versionName "1.0"
ndk {
abiFilters 'arm64-v8a', 'x86_64'
}
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
vectorDrawables {
useSupportLibrary true
@ -44,7 +40,7 @@ android {
composeOptions {
kotlinCompilerExtensionVersion '1.3.1'
}
ndkVersion "25.0.8528842"
ndkVersion "25.1.8937393"
externalNativeBuild {
ndkBuild {
path 'src/main/jni/whisper/Android.mk'

View File

@ -64,16 +64,22 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
private suspend fun copyAssets() = withContext(Dispatchers.IO) {
modelsPath.mkdirs()
samplesPath.mkdirs()
application.copyData("models", modelsPath, ::printMessage)
//application.copyData("models", modelsPath, ::printMessage)
application.copyData("samples", samplesPath, ::printMessage)
printMessage("All data copied to working directory.\n")
}
private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
printMessage("Loading model...\n")
val firstModel = modelsPath.listFiles()!!.first()
whisperContext = WhisperContext.createContext(firstModel.absolutePath)
printMessage("Loaded model ${firstModel.name}.\n")
val models = application.assets.list("models/")
if (models != null) {
val inputstream = application.assets.open("models/" + models[0])
whisperContext = WhisperContext.createContextFromInputStream(inputstream)
printMessage("Loaded model ${models[0]}.\n")
}
//val firstModel = modelsPath.listFiles()!!.first()
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
}
fun transcribeSample() = viewModelScope.launch {

View File

@ -1,8 +1,14 @@
package com.whispercppdemo.whisper
import android.os.Build
import android.util.Log
import kotlinx.coroutines.*
import java.io.File
import java.io.InputStream
import java.util.concurrent.Executors
private const val LOG_TAG = "LibWhisper"
class WhisperContext private constructor(private var ptr: Long) {
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
private val scope: CoroutineScope = CoroutineScope(
@ -34,23 +40,53 @@ class WhisperContext private constructor(private var ptr: Long) {
}
companion object {
fun createContext(filePath: String): WhisperContext {
fun createContextFromFile(filePath: String): WhisperContext {
val ptr = WhisperLib.initContext(filePath)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context with path $filePath")
}
return WhisperContext(ptr)
}
fun createContextFromInputStream(stream: InputStream): WhisperContext {
val ptr = WhisperLib.initContextFromInputStream(stream)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context from input stream")
}
return WhisperContext(ptr)
}
}
}
private class WhisperLib {
companion object {
init {
System.loadLibrary("whisper")
Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}")
var loadVfpv4 = false
if (isArmEabiV7a()) {
// armeabi-v7a needs runtime detection support
val cpuInfo = cpuInfo()
cpuInfo?.let {
Log.d(LOG_TAG, "CPU info: $cpuInfo")
if (cpuInfo.contains("vfpv4")) {
Log.d(LOG_TAG, "CPU supports vfpv4")
loadVfpv4 = true
}
}
}
if (loadVfpv4) {
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so")
System.loadLibrary("whisper_vfpv4")
} else {
Log.d(LOG_TAG, "Loading libwhisper.so")
System.loadLibrary("whisper")
}
}
// JNI methods
external fun initContextFromInputStream(inputStream: InputStream): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
@ -59,3 +95,17 @@ private class WhisperLib {
}
}
private fun isArmEabiV7a(): Boolean {
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
}
private fun cpuInfo(): String? {
return try {
File("/proc/cpuinfo").inputStream().bufferedReader().use {
it.readText()
}
} catch (e: Exception) {
Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e)
null
}
}

View File

@ -1,22 +1,15 @@
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
LOCAL_LDLIBS := -llog
LOCAL_MODULE := libwhisper
include $(LOCAL_PATH)/Whisper.mk
include $(BUILD_SHARED_LIBRARY)
# Make the final output library smaller by only keeping the symbols referenced from the app.
ifneq ($(APP_OPTIM),debug)
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
LOCAL_LDFLAGS += -Wl,--gc-sections
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
LOCAL_LDFLAGS += -flto
endif
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
LOCAL_CPPFLAGS += -std=c++11
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
$(WHISPER_LIB_DIR)/whisper.cpp \
$(LOCAL_PATH)/jni.c
include $(BUILD_SHARED_LIBRARY)
ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_vfpv4
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -mfpu=neon-vfpv4
include $(BUILD_SHARED_LIBRARY)
endif

View File

@ -0,0 +1,18 @@
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
LOCAL_LDLIBS := -llog
# Make the final output library smaller by only keeping the symbols referenced from the app.
ifneq ($(APP_OPTIM),debug)
LOCAL_CFLAGS += -O3
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
LOCAL_LDFLAGS += -Wl,--gc-sections
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
LOCAL_LDFLAGS += -flto
endif
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
LOCAL_CPPFLAGS += -std=c++11
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
$(WHISPER_LIB_DIR)/whisper.cpp \
$(LOCAL_PATH)/jni.c

View File

@ -2,6 +2,7 @@
#include <android/log.h>
#include <stdlib.h>
#include <sys/sysinfo.h>
#include <string.h>
#include "whisper.h"
#define UNUSED(x) (void)(x)
@ -17,13 +18,86 @@ static inline int max(int a, int b) {
return (a > b) ? a : b;
}
struct input_stream_context {
size_t offset;
JNIEnv * env;
jobject thiz;
jobject input_stream;
jmethodID mid_available;
jmethodID mid_read;
};
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
struct input_stream_context* is = (struct input_stream_context*)ctx;
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
if (size_to_copy != read_size || size_to_copy != n_read) {
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
}
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
memcpy(output, byte_array_elements, size_to_copy);
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
(*is->env)->DeleteLocalRef(is->env, byte_array);
is->offset += size_to_copy;
return size_to_copy;
}
bool inputStreamEof(void * ctx) {
struct input_stream_context* is = (struct input_stream_context*)ctx;
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
return result <= 0;
}
void inputStreamClose(void * ctx) {
}
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
JNIEnv *env, jobject thiz, jobject input_stream) {
UNUSED(thiz);
struct whisper_context *context = NULL;
struct whisper_model_loader loader = {};
struct input_stream_context inp_ctx = {};
inp_ctx.offset = 0;
inp_ctx.env = env;
inp_ctx.thiz = thiz;
inp_ctx.input_stream = input_stream;
jclass cls = (*env)->GetObjectClass(env, input_stream);
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
loader.context = &inp_ctx;
loader.read = inputStreamRead;
loader.eof = inputStreamEof;
loader.close = inputStreamClose;
loader.eof(loader.context);
context = whisper_init(&loader);
return (jlong) context;
}
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
JNIEnv *env, jobject thiz, jstring model_path_str) {
UNUSED(thiz);
struct whisper_context *context = NULL;
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
context = whisper_init(model_path_chars);
context = whisper_init_from_file(model_path_chars);
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
return (jlong) context;
}

View File

@ -1,10 +0,0 @@
## This file is automatically generated by Android Studio.
# Do not modify this file -- YOUR CHANGES WILL BE ERASED!
#
# This file should *NOT* be checked into Version Control Systems,
# as it contains information specific to your local configuration.
#
# Location of the SDK. This is only used by Gradle.
# For customization when using a Version Control System, please read the
# header note.
sdk.dir=/Users/kevin/Library/Android/sdk

View File

@ -19,3 +19,8 @@ open whisper.cpp/examples/whisper.objc/whisper.objc.xcodeproj/
Make sure to build the project in `Release`:
<img width="947" alt="image" src="https://user-images.githubusercontent.com/1991296/197382607-9e1e6d1b-79fa-496f-9d16-b71dc1535701.png">
Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag in Build Phases.
This can significantly improve the performance of the transcription:
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">

View File

@ -61,7 +61,7 @@ void AudioInputCallback(void * inUserData,
NSLog(@"Loading model from %@", modelPath);
// create ggml context
stateInp.ctx = whisper_init([modelPath UTF8String]);
stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);
// check if the model was loaded successfully
if (stateInp.ctx == NULL) {

View File

@ -0,0 +1,12 @@
A sample SwiftUI app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
See also: [whisper.objc](https://github.com/ggerganov/whisper.cpp/tree/master/examples/whisper.objc).
To use:
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
2. Add the model to "whisper.swiftui.demo/Resources/models" via Xcode.
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
4. Add the model to "whisper.swiftui.demo/Resources/samples" via Xcode.
5. Select the "release" build configuration under "Run", then deploy and run to your device.
[^1]: I recommend the tiny, base or small models for running on an iOS device.

View File

@ -0,0 +1,70 @@
import Foundation
enum WhisperError: Error {
case couldNotInitializeContext
}
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
actor WhisperContext {
private var context: OpaquePointer
init(context: OpaquePointer) {
self.context = context
}
deinit {
whisper_free(context)
}
func fullTranscribe(samples: [Float]) {
// Leave 2 processors free (i.e. the high-efficiency cores).
let maxThreads = max(1, min(8, cpuCount() - 2))
print("Selecting \(maxThreads) threads")
var params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY)
"en".withCString { en in
// Adapted from whisper.objc
params.print_realtime = true
params.print_progress = false
params.print_timestamps = true
params.print_special = false
params.translate = false
params.language = en
params.n_threads = Int32(maxThreads)
params.offset_ms = 0
params.no_context = true
params.single_segment = false
whisper_reset_timings(context)
print("About to run whisper_full")
samples.withUnsafeBufferPointer { samples in
if (whisper_full(context, params, samples.baseAddress, Int32(samples.count)) != 0) {
print("Failed to run the model")
} else {
whisper_print_timings(context)
}
}
}
}
func getTranscription() -> String {
var transcription = ""
for i in 0..<whisper_full_n_segments(context) {
transcription += String.init(cString: whisper_full_get_segment_text(context, i))
}
return transcription
}
static func createContext(path: String) throws -> WhisperContext {
let context = whisper_init_from_file(path)
if let context {
return WhisperContext(context: context)
} else {
print("Couldn't load model at \(path)")
throw WhisperError.couldNotInitializeContext
}
}
}
fileprivate func cpuCount() -> Int {
ProcessInfo.processInfo.processorCount
}

View File

@ -0,0 +1,4 @@
//
// Use this file to import your target's public headers that you would like to expose to Swift.
//
#import "whisper.h"

View File

@ -0,0 +1,162 @@
import Foundation
import SwiftUI
import AVFoundation
@MainActor
class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
@Published var isModelLoaded = false
@Published var messageLog = ""
@Published var canTranscribe = false
@Published var isRecording = false
private var whisperContext: WhisperContext?
private let recorder = Recorder()
private var recordedFile: URL? = nil
private var audioPlayer: AVAudioPlayer?
private var modelUrl: URL? {
Bundle.main.url(forResource: "ggml-tiny.en", withExtension: "bin", subdirectory: "models")
}
private var sampleUrl: URL? {
Bundle.main.url(forResource: "jfk", withExtension: "wav", subdirectory: "samples")
}
private enum LoadError: Error {
case couldNotLocateModel
}
override init() {
super.init()
do {
try loadModel()
canTranscribe = true
} catch {
print(error.localizedDescription)
messageLog += "\(error.localizedDescription)\n"
}
}
private func loadModel() throws {
messageLog += "Loading model...\n"
if let modelUrl {
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
} else {
messageLog += "Could not locate model\n"
}
}
func transcribeSample() async {
if let sampleUrl {
await transcribeAudio(sampleUrl)
} else {
messageLog += "Could not locate sample\n"
}
}
private func transcribeAudio(_ url: URL) async {
if (!canTranscribe) {
return
}
guard let whisperContext else {
return
}
do {
canTranscribe = false
messageLog += "Reading wave samples...\n"
let data = try readAudioSamples(url)
messageLog += "Transcribing data...\n"
await whisperContext.fullTranscribe(samples: data)
let text = await whisperContext.getTranscription()
messageLog += "Done: \(text)\n"
} catch {
print(error.localizedDescription)
messageLog += "\(error.localizedDescription)\n"
}
canTranscribe = true
}
private func readAudioSamples(_ url: URL) throws -> [Float] {
stopPlayback()
try startPlayback(url)
return try decodeWaveFile(url)
}
func toggleRecord() async {
if isRecording {
await recorder.stopRecording()
isRecording = false
if let recordedFile {
await transcribeAudio(recordedFile)
}
} else {
requestRecordPermission { granted in
if granted {
Task {
do {
self.stopPlayback()
let file = try FileManager.default.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
.appending(path: "output.wav")
try await self.recorder.startRecording(toOutputFile: file, delegate: self)
self.isRecording = true
self.recordedFile = file
} catch {
print(error.localizedDescription)
self.messageLog += "\(error.localizedDescription)\n"
self.isRecording = false
}
}
}
}
}
}
private func requestRecordPermission(response: @escaping (Bool) -> Void) {
#if os(macOS)
response(true)
#else
AVAudioSession.sharedInstance().requestRecordPermission { granted in
response(granted)
}
#endif
}
private func startPlayback(_ url: URL) throws {
audioPlayer = try AVAudioPlayer(contentsOf: url)
audioPlayer?.play()
}
private func stopPlayback() {
audioPlayer?.stop()
audioPlayer = nil
}
// MARK: AVAudioRecorderDelegate
nonisolated func audioRecorderEncodeErrorDidOccur(_ recorder: AVAudioRecorder, error: Error?) {
if let error {
Task {
await handleRecError(error)
}
}
}
private func handleRecError(_ error: Error) {
print(error.localizedDescription)
messageLog += "\(error.localizedDescription)\n"
isRecording = false
}
nonisolated func audioRecorderDidFinishRecording(_ recorder: AVAudioRecorder, successfully flag: Bool) {
Task {
await onDidFinishRecording()
}
}
private func onDidFinishRecording() {
isRecording = false
}
}

View File

@ -0,0 +1,11 @@
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@ -0,0 +1,63 @@
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "512x512"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "512x512"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.device.audio-input</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
</dict>
</plist>

View File

@ -0,0 +1,43 @@
import SwiftUI
import AVFoundation
struct ContentView: View {
@StateObject var whisperState = WhisperState()
var body: some View {
NavigationStack {
VStack {
HStack {
Button("Transcribe", action: {
Task {
await whisperState.transcribeSample()
}
})
.buttonStyle(.bordered)
.disabled(!whisperState.canTranscribe)
Button(whisperState.isRecording ? "Stop recording" : "Start recording", action: {
Task {
await whisperState.toggleRecord()
}
})
.buttonStyle(.bordered)
.disabled(!whisperState.canTranscribe)
}
ScrollView {
Text(verbatim: whisperState.messageLog)
.frame(maxWidth: .infinity, alignment: .leading)
}
}
.navigationTitle("Whisper SwiftUI Demo")
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}

View File

@ -0,0 +1,35 @@
import Foundation
import AVFoundation
actor Recorder {
private var recorder: AVAudioRecorder?
enum RecorderError: Error {
case couldNotStartRecording
}
func startRecording(toOutputFile url: URL, delegate: AVAudioRecorderDelegate?) throws {
let recordSettings: [String : Any] = [
AVFormatIDKey: Int(kAudioFormatLinearPCM),
AVSampleRateKey: 16000.0,
AVNumberOfChannelsKey: 1,
AVEncoderAudioQualityKey: AVAudioQuality.high.rawValue
]
#if !os(macOS)
let session = AVAudioSession.sharedInstance()
try session.setCategory(.playAndRecord, mode: .default)
#endif
let recorder = try AVAudioRecorder(url: url, settings: recordSettings)
recorder.delegate = delegate
if recorder.record() == false {
print("Could not start recording")
throw RecorderError.couldNotStartRecording
}
self.recorder = recorder
}
func stopRecording() {
recorder?.stop()
recorder = nil
}
}

View File

@ -0,0 +1,12 @@
import Foundation
func decodeWaveFile(_ url: URL) throws -> [Float] {
let data = try Data(contentsOf: url)
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
}

View File

@ -0,0 +1,10 @@
import SwiftUI
@main
struct WhisperCppDemoApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}

View File

@ -0,0 +1 @@
xcuserdata/

View File

@ -0,0 +1,468 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 56;
objects = {
/* Begin PBXBuildFile section */
0A8E49002954B3F100704C1B /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 0A8E48FF2954B3F100704C1B /* README.md */; };
0AA751482953AC2E001EE061 /* samples in Resources */ = {isa = PBXBuildFile; fileRef = 0AA751462953AC2E001EE061 /* samples */; };
0AA751492953AC2E001EE061 /* models in Resources */ = {isa = PBXBuildFile; fileRef = 0AA751472953AC2E001EE061 /* models */; };
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AA7514B2953B569001EE061 /* RiffWaveUtils.swift */; };
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AA7514D2953D958001EE061 /* Recorder.swift */; };
0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9A29539CCF003032C3 /* WhisperCppDemoApp.swift */; };
0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9C29539CCF003032C3 /* ContentView.swift */; };
0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9E29539CD0003032C3 /* Assets.xcassets */; };
0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */; };
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-Wno-shorten-64-to-32"; }; };
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
0A8E48FF2954B3F100704C1B /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
0AA751462953AC2E001EE061 /* samples */ = {isa = PBXFileReference; lastKnownFileType = folder; path = samples; sourceTree = "<group>"; };
0AA751472953AC2E001EE061 /* models */ = {isa = PBXFileReference; lastKnownFileType = folder; path = models; sourceTree = "<group>"; };
0AA7514B2953B569001EE061 /* RiffWaveUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RiffWaveUtils.swift; sourceTree = "<group>"; };
0AA7514D2953D958001EE061 /* Recorder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Recorder.swift; sourceTree = "<group>"; };
0AAC5D9729539CCF003032C3 /* whisper.swiftui.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.swiftui.app; sourceTree = BUILT_PRODUCTS_DIR; };
0AAC5D9A29539CCF003032C3 /* WhisperCppDemoApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperCppDemoApp.swift; sourceTree = "<group>"; };
0AAC5D9C29539CCF003032C3 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
0AAC5D9E29539CD0003032C3 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
0AAC5D9429539CCF003032C3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
0AA7513F2953AB32001EE061 /* Models */ = {
isa = PBXGroup;
children = (
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */,
);
path = Models;
sourceTree = "<group>";
};
0AA751402953ABA6001EE061 /* Resources */ = {
isa = PBXGroup;
children = (
0AA751472953AC2E001EE061 /* models */,
0AA751462953AC2E001EE061 /* samples */,
);
path = Resources;
sourceTree = "<group>";
};
0AA7514A2953B561001EE061 /* Utils */ = {
isa = PBXGroup;
children = (
0AA7514B2953B569001EE061 /* RiffWaveUtils.swift */,
0AA7514D2953D958001EE061 /* Recorder.swift */,
);
path = Utils;
sourceTree = "<group>";
};
0AAC5D8E29539CCF003032C3 = {
isa = PBXGroup;
children = (
0A8E48FF2954B3F100704C1B /* README.md */,
0AAC5DC529539E89003032C3 /* whisper.cpp */,
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */,
0AAC5D9929539CCF003032C3 /* whisper.swiftui.demo */,
0AAC5D9829539CCF003032C3 /* Products */,
);
sourceTree = "<group>";
};
0AAC5D9829539CCF003032C3 /* Products */ = {
isa = PBXGroup;
children = (
0AAC5D9729539CCF003032C3 /* whisper.swiftui.app */,
);
name = Products;
sourceTree = "<group>";
};
0AAC5D9929539CCF003032C3 /* whisper.swiftui.demo */ = {
isa = PBXGroup;
children = (
0AA7514A2953B561001EE061 /* Utils */,
0AA751402953ABA6001EE061 /* Resources */,
0AA7513F2953AB32001EE061 /* Models */,
0AAC5DD32953A9ED003032C3 /* Supporting files */,
0AAC5DD22953A9E3003032C3 /* UI */,
0AAC5D9A29539CCF003032C3 /* WhisperCppDemoApp.swift */,
);
path = whisper.swiftui.demo;
sourceTree = "<group>";
};
0AAC5DA129539CD0003032C3 /* Preview Content */ = {
isa = PBXGroup;
children = (
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */,
);
name = "Preview Content";
path = "../Preview Content";
sourceTree = "<group>";
};
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
isa = PBXGroup;
children = (
0AAC5DC929539EB0003032C3 /* ggml.c */,
0AAC5DCA29539EB0003032C3 /* ggml.h */,
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
0AAC5DC829539EB0003032C3 /* whisper.h */,
);
name = whisper.cpp;
path = ../..;
sourceTree = "<group>";
};
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
isa = PBXGroup;
children = (
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */,
0AAC5DD02953A394003032C3 /* LibWhisper.swift */,
);
path = whisper.cpp.swift;
sourceTree = "<group>";
};
0AAC5DD22953A9E3003032C3 /* UI */ = {
isa = PBXGroup;
children = (
0AAC5D9C29539CCF003032C3 /* ContentView.swift */,
);
path = UI;
sourceTree = "<group>";
};
0AAC5DD32953A9ED003032C3 /* Supporting files */ = {
isa = PBXGroup;
children = (
0AAC5D9E29539CD0003032C3 /* Assets.xcassets */,
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */,
0AAC5DA129539CD0003032C3 /* Preview Content */,
);
path = "Supporting files";
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
0AAC5D9629539CCF003032C3 /* whisper.swiftui */ = {
isa = PBXNativeTarget;
buildConfigurationList = 0AAC5DBC29539CD0003032C3 /* Build configuration list for PBXNativeTarget "whisper.swiftui" */;
buildPhases = (
0AAC5D9329539CCF003032C3 /* Sources */,
0AAC5D9429539CCF003032C3 /* Frameworks */,
0AAC5D9529539CCF003032C3 /* Resources */,
);
buildRules = (
);
dependencies = (
);
name = whisper.swiftui;
productName = WhisperCppDemo;
productReference = 0AAC5D9729539CCF003032C3 /* whisper.swiftui.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
0AAC5D8F29539CCF003032C3 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 1410;
LastUpgradeCheck = 1410;
TargetAttributes = {
0AAC5D9629539CCF003032C3 = {
CreatedOnToolsVersion = 14.1;
LastSwiftMigration = 1410;
};
};
};
buildConfigurationList = 0AAC5D9229539CCF003032C3 /* Build configuration list for PBXProject "whisper.swiftui" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 0AAC5D8E29539CCF003032C3;
productRefGroup = 0AAC5D9829539CCF003032C3 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
0AAC5D9629539CCF003032C3 /* whisper.swiftui */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
0AAC5D9529539CCF003032C3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
0AA751482953AC2E001EE061 /* samples in Resources */,
0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */,
0A8E49002954B3F100704C1B /* README.md in Resources */,
0AA751492953AC2E001EE061 /* models in Resources */,
0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
0AAC5D9329539CCF003032C3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */,
0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */,
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */,
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */,
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */,
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
0AAC5DBA29539CD0003032C3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
GCC_C_LANGUAGE_STANDARD = gnu11;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Needed to transcribe audio";
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
};
name = Debug;
};
0AAC5DBB29539CD0003032C3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
GCC_C_LANGUAGE_STANDARD = gnu11;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Needed to transcribe audio";
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_OPTIMIZATION_LEVEL = "-O";
};
name = Release;
};
0AAC5DBD29539CD0003032C3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_ENTITLEMENTS = "whisper.swiftui.demo/Supporting files/WhisperCppDemo.entitlements";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
DEVELOPMENT_TEAM = 3TZ9BM962G;
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.1;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 13.0;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.whispercppdemo.WhisperCppDemo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OBJC_BRIDGING_HEADER = "whisper.cpp.swift/WhisperCppDemo-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
0AAC5DBE29539CD0003032C3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_ENTITLEMENTS = "whisper.swiftui.demo/Supporting files/WhisperCppDemo.entitlements";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
DEVELOPMENT_TEAM = 3TZ9BM962G;
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.1;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
LLVM_LTO = YES;
MACOSX_DEPLOYMENT_TARGET = 13.0;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.whispercppdemo.WhisperCppDemo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OBJC_BRIDGING_HEADER = "whisper.cpp.swift/WhisperCppDemo-Bridging-Header.h";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
0AAC5D9229539CCF003032C3 /* Build configuration list for PBXProject "whisper.swiftui" */ = {
isa = XCConfigurationList;
buildConfigurations = (
0AAC5DBA29539CD0003032C3 /* Debug */,
0AAC5DBB29539CD0003032C3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
0AAC5DBC29539CD0003032C3 /* Build configuration list for PBXNativeTarget "whisper.swiftui" */ = {
isa = XCConfigurationList;
buildConfigurations = (
0AAC5DBD29539CD0003032C3 /* Debug */,
0AAC5DBE29539CD0003032C3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
};
rootObject = 0AAC5D8F29539CCF003032C3 /* Project object */;
}

View File

@ -0,0 +1 @@
contents.xcworkspacedata

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View File

@ -0,0 +1,100 @@
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
LastUpgradeVersion = "1410"
version = "1.3">
<BuildAction
parallelizeBuildables = "YES"
buildImplicitDependencies = "YES">
<BuildActionEntries>
<BuildActionEntry
buildForTesting = "YES"
buildForRunning = "YES"
buildForProfiling = "YES"
buildForArchiving = "YES"
buildForAnalyzing = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "0AAC5D9629539CCF003032C3"
BuildableName = "whisper.swiftui.app"
BlueprintName = "whisper.swiftui"
ReferencedContainer = "container:whisper.swiftui.xcodeproj">
</BuildableReference>
</BuildActionEntry>
</BuildActionEntries>
</BuildAction>
<TestAction
buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
shouldUseLaunchSchemeArgsEnv = "YES">
<Testables>
<TestableReference
skipped = "NO"
parallelizable = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "0AAC5DA729539CD0003032C3"
BuildableName = "whisper.swiftuiTests.xctest"
BlueprintName = "whisper.swiftuiTests"
ReferencedContainer = "container:whisper.swiftui.xcodeproj">
</BuildableReference>
</TestableReference>
<TestableReference
skipped = "NO"
parallelizable = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "0AAC5DB129539CD0003032C3"
BuildableName = "whisper.swiftuiUITests.xctest"
BlueprintName = "whisper.swiftuiUITests"
ReferencedContainer = "container:whisper.swiftui.xcodeproj">
</BuildableReference>
</TestableReference>
</Testables>
</TestAction>
<LaunchAction
buildConfiguration = "Release"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0"
useCustomWorkingDirectory = "NO"
ignoresPersistentStateOnLaunch = "NO"
debugDocumentVersioning = "YES"
debugServiceExtension = "internal"
allowLocationSimulation = "YES">
<BuildableProductRunnable
runnableDebuggingMode = "0">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "0AAC5D9629539CCF003032C3"
BuildableName = "whisper.swiftui.app"
BlueprintName = "whisper.swiftui"
ReferencedContainer = "container:whisper.swiftui.xcodeproj">
</BuildableReference>
</BuildableProductRunnable>
</LaunchAction>
<ProfileAction
buildConfiguration = "Release"
shouldUseLaunchSchemeArgsEnv = "YES"
savedToolIdentifier = ""
useCustomWorkingDirectory = "NO"
debugDocumentVersioning = "YES">
<BuildableProductRunnable
runnableDebuggingMode = "0">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "0AAC5D9629539CCF003032C3"
BuildableName = "whisper.swiftui.app"
BlueprintName = "whisper.swiftui"
ReferencedContainer = "container:whisper.swiftui.xcodeproj">
</BuildableReference>
</BuildableProductRunnable>
</ProfileAction>
<AnalyzeAction
buildConfiguration = "Debug">
</AnalyzeAction>
<ArchiveAction
buildConfiguration = "Release"
revealArchiveInOrganizer = "YES">
</ArchiveAction>
</Scheme>

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
return i + 1;
} else {

2042
ggml.c

File diff suppressed because it is too large Load Diff

4
ggml.h
View File

@ -724,11 +724,15 @@ enum ggml_opt_result ggml_opt(
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);
int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
#ifdef __cplusplus
}

View File

@ -56,7 +56,7 @@ def bytes_to_unicode():
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""

View File

@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%

View File

@ -19,7 +19,7 @@ function get_script_path() {
fi
}
models_path=$(get_script_path)
models_path="$(get_script_path)"
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
#ifndef WHISPER_H
#define WHISPER_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
@ -40,7 +41,7 @@ extern "C" {
//
// ...
//
// struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
// struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
//
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
// fprintf(stderr, "failed to process audio\n");
@ -73,6 +74,7 @@ extern "C" {
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
@ -84,9 +86,20 @@ extern "C" {
float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx);
@ -124,6 +137,7 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
@ -131,14 +145,6 @@ extern "C" {
int n_past,
int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
@ -177,10 +183,14 @@ extern "C" {
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -209,8 +219,8 @@ extern "C" {
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};
// Text segment callback
@ -230,17 +240,17 @@ extern "C" {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
bool no_context; // do not use initial prompt for the decoder (if any)
bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
@ -250,10 +260,11 @@ extern "C" {
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
@ -261,19 +272,35 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int n_past;
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy;
struct {
int n_past;
int beam_width;
int n_best;
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};