mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-24 17:15:19 +00:00
Compare commits
369 Commits
experiment
...
v1.1.0
Author | SHA1 | Date | |
---|---|---|---|
8738427dd6 | |||
c3991bbb24 | |||
00ea21668b | |||
0b85e8c401 | |||
fafd78945d | |||
8de452c18b | |||
a6dbd9188b | |||
4ef3398e8f | |||
5e9f33596f | |||
8d7b29cedd | |||
08dc705a69 | |||
1512545149 | |||
52a3e0c92a | |||
d1ea1220ff | |||
9c4a1522f6 | |||
f078a6f20e | |||
f30b5d322c | |||
44efbf7ff1 | |||
d347a59a5f | |||
6394c906af | |||
74ffa14e1d | |||
65fdcbbbbb | |||
d61d55cd4b | |||
d51fc3ee0a | |||
f82a7dd019 | |||
87dd4a3081 | |||
41e05c6b1b | |||
fa379cb22a | |||
322f4e6c4e | |||
1652965529 | |||
6042c7a3be | |||
6b351bb669 | |||
a62170c656 | |||
1944e7c33e | |||
49a8dd6732 | |||
8c7f642286 | |||
ad2a4ffa03 | |||
b3c865083e | |||
a0d4f8e65c | |||
4a214d2f07 | |||
0a0cfa7985 | |||
196d738974 | |||
84c6b42e65 | |||
dd6d582977 | |||
d51c5eb906 | |||
0be6a1afd9 | |||
a466c3404d | |||
d629c034a4 | |||
f00509d57c | |||
424c410c42 | |||
d97e6005e9 | |||
3467230a77 | |||
a091581eb3 | |||
68daf6e487 | |||
a593b932e4 | |||
9a8ad3db69 | |||
4e0b2069e7 | |||
ac521a566e | |||
331c0bbddc | |||
dc90efd504 | |||
7282e2109e | |||
466ceebb78 | |||
77226aa89d | |||
543bd5627e | |||
62fee9a9cc | |||
493d94130d | |||
1480a5f1af | |||
0f4227d9ee | |||
4c1fe0c813 | |||
fa463313ad | |||
501a6b455c | |||
91fc08c641 | |||
e1432dd91a | |||
22193cbfe8 | |||
42c6730732 | |||
76b6211f9b | |||
86a277f78d | |||
231bebca7d | |||
90564f85f9 | |||
99da1e5cc8 | |||
8e3f129b4d | |||
1d716d6e34 | |||
419b8a6402 | |||
1eb81f863f | |||
fba10a4c68 | |||
afe2db0fe2 | |||
a7047b2a28 | |||
32fbc8cd04 | |||
b8065d90f5 | |||
4312995974 | |||
5eeeb3412d | |||
6a69e3ae27 | |||
bf69b669a0 | |||
ea19ed33f1 | |||
675e787171 | |||
c6c3ad5a98 | |||
6a7c82501e | |||
a82d331034 | |||
c37c2443c1 | |||
0f11759406 | |||
5a5c5ddcca | |||
34e0b4b9ef | |||
b0f8013eb9 | |||
124c718c73 | |||
f66ac6dc4f | |||
9955fa4ed7 | |||
a613f16aec | |||
930c693989 | |||
d8a0dde31a | |||
9e3e6f253a | |||
57ccd7cc4f | |||
812ae3ffbd | |||
f309f97df6 | |||
aa6adda26e | |||
444349f4ec | |||
37a93d2459 | |||
e70d47baab | |||
6ed786957e | |||
ea38ad6e70 | |||
054940e1f6 | |||
fcf515de60 | |||
85c9ac18b5 | |||
b7c85d1ea6 | |||
3b1aacbe6d | |||
d1da35de06 | |||
603f97ba11 | |||
50a061b313 | |||
832b4f34c9 | |||
0f98755fc5 | |||
56822621a8 | |||
9e5f3ddc16 | |||
d91c001120 | |||
04a16bbf11 | |||
47afb93c3c | |||
575c53dc41 | |||
3996ecc156 | |||
faa85f9840 | |||
b6597539f9 | |||
9a4b7a916e | |||
f8ec718b76 | |||
35b40a93b9 | |||
9fe7306f4b | |||
13e8eb2346 | |||
78d13257be | |||
9b7df68753 | |||
061fc81bd6 | |||
57e0e6b700 | |||
4f7363077f | |||
093c840dee | |||
e7f09a0a61 | |||
4698dcdb52 | |||
6fd5358dd0 | |||
164df0d447 | |||
e266cb0723 | |||
c207eed431 | |||
67e819baf4 | |||
a425365b82 | |||
e0e864d9ca | |||
68ecadbbc9 | |||
c536ff4005 | |||
cb70b07db5 | |||
3c390ffe38 | |||
be16dfa038 | |||
0f619b52ce | |||
1246dd023e | |||
0be27bbd92 | |||
bc88eb13c6 | |||
b8ce25dec1 | |||
fd113687aa | |||
e4805d9601 | |||
ff36415a86 | |||
025ff465b6 | |||
2c0501b38a | |||
abce28ea99 | |||
a2ecd54455 | |||
128aaadb93 | |||
454b91de16 | |||
d7024cf9dc | |||
37422ed733 | |||
be3b720f96 | |||
00f46dbc1d | |||
5698bddbc9 | |||
388e9f79ad | |||
35cd29ce1f | |||
a156a358ca | |||
6a84147113 | |||
804f36aa2c | |||
4b2f51b479 | |||
800ae5b808 | |||
83456076f0 | |||
3df6c14fca | |||
d64d6ca3fd | |||
93482d0373 | |||
49706a658a | |||
363a2dadec | |||
623a486056 | |||
2f596f5b33 | |||
e5dcdabbb8 | |||
dad109c3f1 | |||
326573de9a | |||
9aea96f774 | |||
385236d1d3 | |||
63ae03b8e0 | |||
78116f8eda | |||
a4dfbeecf9 | |||
2e311a2917 | |||
2065572a11 | |||
5c2176e314 | |||
f2df9bd768 | |||
fb8d77f760 | |||
62b5ff875c | |||
d351771a4b | |||
c058aaf22e | |||
2ba66360c9 | |||
e70e5c8b53 | |||
55a0e1a64e | |||
864a78a8d0 | |||
83c742f1a7 | |||
41b48ab7f1 | |||
a728be9cdb | |||
46a68fb9b5 | |||
ccd56a9c5b | |||
3500ce8727 | |||
7519eabf65 | |||
b21213c23e | |||
9e700e1821 | |||
0bfe728b84 | |||
4e5674a5d5 | |||
4c66b6a828 | |||
c30bffc8a5 | |||
8fdfb0ba92 | |||
c71363f14c | |||
a09e9123ca | |||
d42cf6d0df | |||
ef47d77492 | |||
75171c2b79 | |||
a2eeb941f6 | |||
0e689f83d8 | |||
d5afebd37c | |||
4b1c32e8ea | |||
b5dde365e9 | |||
02dfd5b8c3 | |||
c63ce24834 | |||
137321915f | |||
24cd12f647 | |||
e46bc56e71 | |||
6fb98370ba | |||
0729da9a3b | |||
5dc74e3aff | |||
ac8ef34039 | |||
b26345cc7b | |||
8dac3c6e10 | |||
6417e59aad | |||
dc12994603 | |||
b0f2aa0ea6 | |||
57fb46f307 | |||
5a9e4260a6 | |||
eba62e0fa1 | |||
69bdb6624a | |||
12fb303d9d | |||
234f414652 | |||
014a119052 | |||
dec40be58f | |||
e5044f87d9 | |||
2827cbbbe8 | |||
0b2dc3c82c | |||
a272f10b2e | |||
85d6e1e1e7 | |||
72e9cdd6bf | |||
c565c569e7 | |||
2c281d190b | |||
b89f8960ca | |||
6f82320b05 | |||
2298310dd8 | |||
8347a7bb6a | |||
fbd513b813 | |||
ebb01b9e33 | |||
9820234f13 | |||
a22e5741d8 | |||
2400660f3f | |||
058a27b2e5 | |||
a09ce6e889 | |||
a6c786d5dc | |||
9ccafa8792 | |||
89d8ee3ee5 | |||
91dcf5f35b | |||
113a4f06d8 | |||
47e78b7288 | |||
34bb3ab0cf | |||
c6710efde2 | |||
4c68f4cac0 | |||
728676927f | |||
d4f94ce427 | |||
a52ee08c1e | |||
b41f4a90eb | |||
bb1ee266d2 | |||
5f7e9fa2dc | |||
181b762de8 | |||
3d37ad5133 | |||
4e887dc350 | |||
4196856c7b | |||
705198f063 | |||
3e69a6071d | |||
f3dae90c31 | |||
6a81ed3e78 | |||
7affd309d3 | |||
8f95c25aed | |||
31ff0c6a1f | |||
f4aa01c2f8 | |||
8c1d970088 | |||
6b45e37b2b | |||
491ecd7056 | |||
db460b78ff | |||
e905c6f827 | |||
7d0dee7a8a | |||
8d15a1c635 | |||
692aa0784f | |||
80aefc9514 | |||
5698b51718 | |||
3fe3898ebb | |||
81c185576c | |||
744bd47685 | |||
66b3169d39 | |||
19a780afe5 | |||
1969ee4bc7 | |||
0e4fd43400 | |||
19817711b4 | |||
7eeef0358a | |||
632660abb9 | |||
e36aabe00d | |||
2d171ced32 | |||
e30cf83158 | |||
d6b84b2a23 | |||
b4a3875b2c | |||
cf67bfffa0 | |||
91632eb6ea | |||
b81a81d543 | |||
d14823582d | |||
20d8e7a309 | |||
72d967bce4 | |||
130b5c02d6 | |||
0e858f080d | |||
f24d940ca9 | |||
949f97a8b4 | |||
0ad085f5e8 | |||
36945162fa | |||
b2f1600aa3 | |||
b799226973 | |||
1348796a93 | |||
40609cb49b | |||
0b45d25151 | |||
28252352d7 | |||
8d94358251 | |||
ad6693fb64 | |||
01c9e96f64 | |||
63b6786767 | |||
f7ab81fe51 | |||
eac4f12777 | |||
9d5723435f | |||
6e29d8453c | |||
50b5fe964c | |||
64752acd27 | |||
7edaa7da4b | |||
4bbb8a587b | |||
4a6bf11db3 | |||
9bbca3110f | |||
5e563ef635 | |||
2ca8cc77b2 | |||
8c7c018893 |
17
.github/workflows/bindings.yml
vendored
Normal file
17
.github/workflows/bindings.yml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
name: Bindings Tests
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '^1.19'
|
||||
- uses: actions/checkout@v1
|
||||
- run: |
|
||||
cd bindings/go
|
||||
make test
|
152
.github/workflows/build.yml
vendored
152
.github/workflows/build.yml
vendored
@ -113,3 +113,155 @@ jobs:
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
windows:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch OpenBLAS
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
|
||||
7z x blas.zip -oblas -y
|
||||
copy blas/include/cblas.h .
|
||||
copy blas/include/openblas_config.h .
|
||||
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy libopenblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
32
.gitignore
vendored
32
.gitignore
vendored
@ -1,7 +1,31 @@
|
||||
sync.sh
|
||||
main
|
||||
stream
|
||||
*.o
|
||||
.cache
|
||||
.cache/
|
||||
.vs/
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
||||
build/
|
||||
build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
build-static/
|
||||
build-sanitize-addr/
|
||||
build-sanitize-thread/
|
||||
|
||||
/main
|
||||
/stream
|
||||
/command
|
||||
/talk
|
||||
/bench
|
||||
|
||||
sync.sh
|
||||
libwhisper.a
|
||||
libwhisper.so
|
||||
compile_commands.json
|
||||
|
||||
examples/arm_neon.h
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/xcshareddata
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
|
||||
|
||||
extra/bench-gg.txt
|
||||
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "bindings/ios"]
|
||||
path = bindings/ios
|
||||
url = https://github.com/ggerganov/whisper.spm
|
201
CMakeLists.txt
201
CMakeLists.txt
@ -1,44 +1,82 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
project(whisper.cpp VERSION 1.0.0)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
project(whisper.cpp VERSION 1.1.0)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
include(GitVars)
|
||||
include(BuildTypes)
|
||||
|
||||
# configure project version
|
||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||
configure_file(${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl ${CMAKE_SOURCE_DIR}/bindings/ios/Makefile @ONLY)
|
||||
endif()
|
||||
configure_file(${CMAKE_SOURCE_DIR}/bindings/javascript/package-tmpl.json ${CMAKE_SOURCE_DIR}/bindings/javascript/package.json @ONLY)
|
||||
else()
|
||||
set(WHISPER_STANDALONE OFF)
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
option(WHISPER_WASM_SINGLE_FILE "whisper: embed WASM inside the generated whisper.js" ON)
|
||||
else()
|
||||
if (MINGW)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
else()
|
||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# options
|
||||
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
|
||||
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
|
||||
|
||||
option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
|
||||
|
||||
if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
endif()
|
||||
|
||||
option(WHISPER_PERF "whisper: enable perf timings" OFF)
|
||||
|
||||
# sanitizers
|
||||
|
||||
if (WHISPER_SANITIZE_THREAD)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
|
||||
endif()
|
||||
if (NOT MSVC)
|
||||
if (WHISPER_SANITIZE_THREAD)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
|
||||
endif()
|
||||
|
||||
if (WHISPER_SANITIZE_ADDRESS)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
endif()
|
||||
if (WHISPER_SANITIZE_ADDRESS)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
endif()
|
||||
|
||||
if (WHISPER_SANITIZE_UNDEFINED)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined")
|
||||
if (WHISPER_SANITIZE_UNDEFINED)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
|
||||
@ -46,19 +84,33 @@ endif()
|
||||
|
||||
# dependencies
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# SDL2
|
||||
find_package(SDL2 REQUIRED)
|
||||
# on APPLE - include Accelerate framework
|
||||
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
string(STRIP "${SDL2_LIBRARIES}" SDL2_LIBRARIES)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "SDL2_INCLUDE_DIRS = ${SDL2_INCLUDE_DIRS}")
|
||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||
if (WHISPER_SUPPORT_OPENBLAS)
|
||||
find_library(OPENBLAS_LIB
|
||||
NAMES openblas libopenblas
|
||||
)
|
||||
if (OPENBLAS_LIB)
|
||||
message(STATUS "OpenBLAS found")
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${OPENBLAS_LIB})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
|
||||
else()
|
||||
message(WARNING "OpenBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# compiler flags
|
||||
@ -69,7 +121,7 @@ if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
|
||||
endif ()
|
||||
|
||||
if (WHISPER_ALL_WARNINGS)
|
||||
if (CMAKE_COMPILER_IS_GNUCC OR CMAKE_C_COMPILER_ID MATCHES "Clang")
|
||||
if (NOT MSVC)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
@ -78,14 +130,23 @@ if (WHISPER_ALL_WARNINGS)
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wpedantic \
|
||||
-Wcast-qual \
|
||||
")
|
||||
else()
|
||||
# todo : windows
|
||||
# todo : msvc
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
|
||||
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
|
||||
if (NOT MSVC)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
|
||||
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
|
||||
@ -93,23 +154,59 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
|
||||
message(STATUS "ARM detected")
|
||||
else()
|
||||
message(STATUS "x86 detected")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
|
||||
else()
|
||||
if(NOT WHISPER_NO_AVX)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
|
||||
endif()
|
||||
if(NOT WHISPER_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_PERF)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
||||
endif()
|
||||
|
||||
#
|
||||
# whisper - this is the main library of the project
|
||||
#
|
||||
|
||||
set(TARGET whisper)
|
||||
|
||||
add_library(${TARGET}
|
||||
ggml.h
|
||||
ggml.c
|
||||
whisper.h
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${CMAKE_THREAD_LIBS_INIT})
|
||||
if (MSVC)
|
||||
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -D_CRT_SECURE_NO_WARNINGS)
|
||||
else()
|
||||
target_link_libraries(${TARGET} PRIVATE m ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif()
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
target_link_libraries(${TARGET} PUBLIC
|
||||
@ -121,6 +218,10 @@ if (BUILD_SHARED_LIBS)
|
||||
)
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-msimd128")
|
||||
endif()
|
||||
|
||||
target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
@ -128,26 +229,24 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
|
||||
#
|
||||
# bindings
|
||||
#
|
||||
|
||||
add_subdirectory(bindings)
|
||||
|
||||
#
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
||||
if (WHISPER_STANDALONE)
|
||||
# main
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
|
||||
if (WHISPER_BUILD_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
|
193
Makefile
193
Makefile
@ -1,16 +1,38 @@
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
|
||||
ifndef UNAME_P
|
||||
UNAME_P := $(shell uname -p)
|
||||
endif
|
||||
|
||||
ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
ifneq ($(UNAME_P),arm)
|
||||
SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
|
||||
ifeq ($(SYSCTL_M),1)
|
||||
# UNAME_P := arm
|
||||
# UNAME_M := arm64
|
||||
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
#
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
CFLAGS = -O3 -std=c11
|
||||
CXXFLAGS = -O3 -std=c++11
|
||||
|
||||
CFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function
|
||||
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function
|
||||
CFLAGS = -I. -O3 -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# OS specific
|
||||
# TODO: support Windows
|
||||
@ -22,17 +44,101 @@ ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
ifeq ($(UNAME_S),FreeBSD)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
ifeq ($(UNAME_S),Haiku)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
|
||||
# Architecture specific
|
||||
ifeq ($(UNAME_P),x86_64)
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mf16c
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
|
||||
ifneq (,$(findstring AVX2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Linux)
|
||||
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
FMA_M := $(shell grep "fma " /proc/cpuinfo)
|
||||
ifneq (,$(findstring fma,$(FMA_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
FMA_M := $(shell sysinfo -cpu | grep "FMA ")
|
||||
ifneq (,$(findstring fma,$(FMA_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
else
|
||||
CFLAGS += -mfma -mf16c -mavx -mavx2
|
||||
endif
|
||||
endif
|
||||
ifeq ($(UNAME_M),amd64)
|
||||
CFLAGS += -mavx -mavx2 -mfma -mf16c
|
||||
endif
|
||||
ifneq ($(filter arm%,$(UNAME_P)),)
|
||||
# Mac M1
|
||||
endif
|
||||
ifneq ($(filter aarch64%,$(UNAME_P)),)
|
||||
ifeq ($(UNAME_M),ppc64le)
|
||||
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
||||
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
||||
CFLAGS += -mpower9-vector
|
||||
endif
|
||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||
endif
|
||||
ifndef WHISPER_NO_ACCELERATE
|
||||
# Mac M1 - include Accelerate framework
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -DGGML_USE_ACCELERATE
|
||||
LDFLAGS += -framework Accelerate
|
||||
endif
|
||||
endif
|
||||
ifdef WHISPER_OPENBLAS
|
||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
|
||||
LDFLAGS += -lopenblas
|
||||
endif
|
||||
ifdef WHISPER_GPROF
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
endif
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
endif
|
||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||
# Raspberry Pi 1, 2, 3
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
@ -46,21 +152,40 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
endif
|
||||
|
||||
#
|
||||
# Build library + main
|
||||
# Print build information
|
||||
#
|
||||
|
||||
main: main.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) main.cpp whisper.o ggml.o -o main
|
||||
./main -h
|
||||
$(info I whisper.cpp build info: )
|
||||
$(info I UNAME_S: $(UNAME_S))
|
||||
$(info I UNAME_P: $(UNAME_P))
|
||||
$(info I UNAME_M: $(UNAME_M))
|
||||
$(info I CFLAGS: $(CFLAGS))
|
||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
||||
$(info I LDFLAGS: $(LDFLAGS))
|
||||
$(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
default: main
|
||||
|
||||
#
|
||||
# Build library
|
||||
#
|
||||
|
||||
ggml.o: ggml.c ggml.h
|
||||
$(CC) $(CFLAGS) -c ggml.c
|
||||
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h
|
||||
$(CXX) $(CXXFLAGS) -c whisper.cpp
|
||||
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
|
||||
|
||||
libwhisper.a: ggml.o whisper.o
|
||||
$(AR) rcs libwhisper.a ggml.o whisper.o
|
||||
|
||||
libwhisper.so: ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main
|
||||
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -68,8 +193,21 @@ clean:
|
||||
|
||||
CC_SDL=`sdl2-config --cflags --libs`
|
||||
|
||||
stream: stream.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) stream.cpp ggml.o whisper.o -o stream $(CC_SDL)
|
||||
main: examples/main/main.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
|
||||
./main -h
|
||||
|
||||
stream: examples/stream/stream.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
bench: examples/bench/bench.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
||||
|
||||
#
|
||||
# Audio samples
|
||||
@ -106,10 +244,11 @@ samples:
|
||||
.PHONY: small
|
||||
.PHONY: medium.en
|
||||
.PHONY: medium
|
||||
.PHONY: large-v1
|
||||
.PHONY: large
|
||||
|
||||
tiny.en tiny base.en base small.en small medium.en medium large: main
|
||||
bash ./download-ggml-model.sh $@
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
|
||||
bash ./models/download-ggml-model.sh $@
|
||||
@echo ""
|
||||
@echo "==============================================="
|
||||
@echo "Running $@ on all samples in ./samples ..."
|
||||
@ -117,9 +256,17 @@ tiny.en tiny base.en base small.en small medium.en medium large: main
|
||||
@echo ""
|
||||
@for f in samples/*.wav; do \
|
||||
echo "----------------------------------------------" ; \
|
||||
echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
|
||||
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
|
||||
echo "----------------------------------------------" ; \
|
||||
echo "" ; \
|
||||
./main -m models/ggml-$@.bin -f $$f ; \
|
||||
echo "" ; \
|
||||
done
|
||||
|
||||
#
|
||||
# Tests
|
||||
#
|
||||
|
||||
.PHONY: tests
|
||||
tests:
|
||||
bash ./tests/run-tests.sh
|
||||
|
409
README.md
409
README.md
@ -2,30 +2,76 @@
|
||||
|
||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [1.0.4](https://github.com/ggerganov/whisper.cpp/releases/tag/1.0.4) / Beta: [1.1.0](https://github.com/ggerganov/whisper.cpp/releases/tag/1.1.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- ARM_NEON and AVX intrinsics support
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- Zero memory allocations at runtime
|
||||
- Runs on the CPU
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
- Supported platforms: Linux, Mac OS (Intel and Arm), Raspberry Pi, Android
|
||||
|
||||
## Usage
|
||||
Supported platforms:
|
||||
|
||||
To build the main program, run `make`. You can then transcribe a `.wav` file like this:
|
||||
- [x] Mac OS (Intel and Arm)
|
||||
- [x] [iOS](examples/whisper.objc)
|
||||
- [x] [Android](examples/whisper.android)
|
||||
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
|
||||
- [x] [WebAssembly](examples/whisper.wasm)
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
|
||||
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
|
||||
|
||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/197385372-962a6dea-bca1-4d50-bf96-1d8c27b98c81.mp4
|
||||
|
||||
You can also easily make your own offline voice assistant application: [command](examples/command)
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
|
||||
|
||||
## Implementation details
|
||||
|
||||
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
|
||||
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
|
||||
- Sample usage is demonstrated in [main.cpp](examples/main)
|
||||
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
|
||||
- Various other examples are available in the [examples](examples) folder
|
||||
|
||||
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
|
||||
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
|
||||
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
|
||||
|
||||
## Quick start
|
||||
|
||||
First, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
|
||||
```bash
|
||||
$ ./main -f input.wav
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
```
|
||||
|
||||
Before running the program, make sure to download one of the ggml Whisper models. For example:
|
||||
Now build the [main](examples/main) example and transcribe an audio file like this:
|
||||
|
||||
```bash
|
||||
bash ./download-ggml-model.sh base.en
|
||||
# build the main example
|
||||
make
|
||||
|
||||
# transcribe an audio file
|
||||
./main -f samples/jfk.wav
|
||||
```
|
||||
|
||||
---
|
||||
@ -34,28 +80,49 @@ For a quick demo, simply run `make base.en`:
|
||||
|
||||
```java
|
||||
$ make base.en
|
||||
cc -O3 -std=c11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c ggml.c
|
||||
c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c whisper.cpp
|
||||
c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread main.cpp whisper.o ggml.o -o main
|
||||
|
||||
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
|
||||
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
|
||||
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
|
||||
./main -h
|
||||
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
-s SEED, --seed SEED RNG seed (default: -1)
|
||||
-t N, --threads N number of threads to use during computation (default: 4)
|
||||
-v, --verbose verbose output
|
||||
--translate translate from source language to english
|
||||
-ps, --print_special print special tokens
|
||||
-nt, --no_timestamps do not print timestamps
|
||||
-l LANG, --language LANG spoken language (default: en)
|
||||
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
|
||||
-f FNAME, --file FNAME input WAV file path
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
|
||||
bash ./download-ggml-model.sh base.en
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
Downloading ggml model base.en ...
|
||||
models/ggml-base.en.bin 100%[===================================>] 141.11M 6.49MB/s in 23s
|
||||
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
|
||||
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
|
||||
You can now use it like this:
|
||||
|
||||
@ -83,30 +150,33 @@ whisper_model_load: n_text_layer = 6
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 2
|
||||
whisper_model_load: mem_required = 377.00 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: ggml ctx size = 163.43 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
whisper_model_load: mem_required = 506.00 MB
|
||||
whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ...
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
|
||||
whisper_print_timings: load time = 77.48 ms
|
||||
whisper_print_timings: mel time = 26.10 ms
|
||||
whisper_print_timings: sample time = 2.19 ms
|
||||
whisper_print_timings: encode time = 632.95 ms / 105.49 ms per layer
|
||||
whisper_print_timings: decode time = 85.11 ms / 14.18 ms per layer
|
||||
whisper_print_timings: total time = 824.14 ms
|
||||
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
|
||||
whisper_print_timings: load time = 105.91 ms
|
||||
whisper_print_timings: mel time = 24.62 ms
|
||||
whisper_print_timings: sample time = 3.63 ms
|
||||
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
|
||||
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
|
||||
whisper_print_timings: total time = 542.81 ms
|
||||
```
|
||||
|
||||
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
|
||||
|
||||
For detailed usage instructions, run: `./main -h`
|
||||
|
||||
Note that `whisper.cpp` currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
|
||||
Note that the [main](examples/main) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
|
||||
For example, you can use `ffmpeg` like this:
|
||||
|
||||
```java
|
||||
@ -134,13 +204,32 @@ make small.en
|
||||
make small
|
||||
make medium.en
|
||||
make medium
|
||||
make large-v1
|
||||
make large
|
||||
```
|
||||
|
||||
## Memory usage
|
||||
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
- No GPU support (yet)
|
||||
|
||||
## Another example
|
||||
|
||||
Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
|
||||
in less than a minute on a MacBook M1 Pro, using `medium.en` model:
|
||||
in about half a minute on a MacBook M1 Pro, using `medium.en` model:
|
||||
|
||||
<details>
|
||||
<summary>Expand to see the result</summary>
|
||||
|
||||
```java
|
||||
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
|
||||
@ -158,86 +247,188 @@ whisper_model_load: n_text_layer = 24
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 4
|
||||
whisper_model_load: mem_required = 2502.00 MB
|
||||
whisper_model_load: mem_required = 2610.00 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: ggml ctx size = 1644.97 MB
|
||||
whisper_model_load: memory size = 182.62 MB
|
||||
whisper_model_load: model size = 1462.12 MB
|
||||
log_mel_spectrogram: n_sample = 3179750, n_len = 19873
|
||||
log_mel_spectrogram: recording length: 198.734375 s
|
||||
|
||||
main: processing 3179750 samples (198.7 sec), 8 threads, lang = english, task = transcribe, timestamps = 1 ...
|
||||
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
[00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
|
||||
[00:08.000 --> 00:17.000] At 9 o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:17.000 --> 00:24.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:24.000 --> 00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:29.000 --> 00:32.000] On board was a crew of seven.
|
||||
[00:32.000 --> 00:43.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool,
|
||||
[00:43.000 --> 00:52.000] Dr. Kultner Aschavla, and Elon Ramon, a Colonel in the Israeli Air Force.
|
||||
[00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
|
||||
[00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
|
||||
[00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
|
||||
[00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
|
||||
[00:58.000 --> 01:06.000] In an age when space flight has come to seem almost routine, it is easy to overlook the dangers of travel by rocket
|
||||
[01:06.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[01:12.000 --> 01:22.000] These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life.
|
||||
[01:22.000 --> 01:30.000] Because of their courage, endearing, and idealism, we will miss them all the more.
|
||||
[01:30.000 --> 01:40.000] All Americans today are thinking as well of the families of these men and women who have been given this sudden shock and grief.
|
||||
[01:40.000 --> 01:45.000] You're not alone. Our entire nation agrees with you.
|
||||
[01:45.000 --> 01:52.000] And those you love will always have the respect and gratitude of this country.
|
||||
[00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
|
||||
[01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
|
||||
[01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
|
||||
[01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
|
||||
[01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
|
||||
[01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
|
||||
[01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
|
||||
[01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
|
||||
[01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
|
||||
[01:52.000 --> 01:56.000] The cause in which they died will continue.
|
||||
[01:56.000 --> 02:07.000] Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand.
|
||||
[02:07.000 --> 02:11.000] Our journey into space will go on.
|
||||
[01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
|
||||
[02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
|
||||
[02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
|
||||
[02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
|
||||
[02:22.000 --> 02:31.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens who created all these.
|
||||
[02:31.000 --> 02:39.000] He who brings out the starry hosts one by one and calls them each by name."
|
||||
[02:39.000 --> 02:46.000] Because of his great power and mighty strength, not one of them is missing.
|
||||
[02:46.000 --> 02:55.000] The same creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[02:55.000 --> 03:05.000] The crew of the shuttle Columbia did not return safely to Earth, yet we can pray that all are safely home.
|
||||
[03:05.000 --> 03:14.000] May God bless the grieving families and may God continue to bless America.
|
||||
[03:14.000 --> 03:24.000] [Music]
|
||||
[02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
|
||||
[02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
|
||||
[02:35.000 --> 02:39.000] and calls them each by name."
|
||||
[02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
|
||||
[02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
|
||||
[03:01.000 --> 03:05.000] yet we can pray that all are safely home.
|
||||
[03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
|
||||
[03:13.000 --> 03:41.000] Audio
|
||||
|
||||
|
||||
main: load time = 522.18 ms
|
||||
main: mel time = 423.43 ms
|
||||
main: sample time = 31.42 ms
|
||||
main: encode time = 41518.51 ms / 1729.94 ms per layer
|
||||
main: decode time = 14907.22 ms
|
||||
main: total time = 57416.63 ms
|
||||
whisper_print_timings: load time = 575.92 ms
|
||||
whisper_print_timings: mel time = 230.60 ms
|
||||
whisper_print_timings: sample time = 73.19 ms
|
||||
whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
|
||||
whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
|
||||
whisper_print_timings: total time = 33686.27 ms
|
||||
```
|
||||
</details>
|
||||
|
||||
## Real-time audio input example
|
||||
|
||||
This is a naive example of performing real-time inference on audio from your microphone.
|
||||
The `stream` tool samples the audio every 3 seconds and runs the transcription continously. More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
$ ./stream -m models/ggml-small.en.bin -t 8
|
||||
make stream
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/193465125-c163d304-64f6-4f5d-83e5-72239c9a203e.mp4
|
||||
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
|
||||
|
||||
## Implementation details
|
||||
## Confidence color-coding
|
||||
|
||||
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
|
||||
- The high-level C-style API is implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
|
||||
- Simple usage is demonstrated in [main.cpp](main.cpp)
|
||||
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](stream.cpp)
|
||||
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
|
||||
to highlight words with high or low confidence:
|
||||
|
||||
## Limitations
|
||||
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
|
||||
|
||||
- Very basic greedy sampling scheme - always pick up the top token. You can implement your own strategy
|
||||
- Inference only
|
||||
- No GPU support
|
||||
## Controlling the length of the generated text segments (experimental)
|
||||
|
||||
## Memory usage
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
|
||||
| Model | Disk | Mem |
|
||||
| --- | --- | --- |
|
||||
| tiny | 75 MB | ~240 MB |
|
||||
| base | 142 MB | ~380 MB |
|
||||
| small | 466 MB | ~970 MB |
|
||||
| medium | 1.5 GB | ~2.5 GB |
|
||||
| large | 2.9 GB | ~4.6 GB |
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
[00:00:00.000 --> 00:00:00.850] And so my
|
||||
[00:00:00.850 --> 00:00:01.590] fellow
|
||||
[00:00:01.590 --> 00:00:04.140] Americans, ask
|
||||
[00:00:04.140 --> 00:00:05.660] not what your
|
||||
[00:00:05.660 --> 00:00:06.840] country can do
|
||||
[00:00:06.840 --> 00:00:08.430] for you, ask
|
||||
[00:00:08.430 --> 00:00:09.440] what you can do
|
||||
[00:00:09.440 --> 00:00:10.020] for your
|
||||
[00:00:10.020 --> 00:00:11.000] country.
|
||||
```
|
||||
|
||||
## Word-level timestamp
|
||||
|
||||
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
[00:00:00.000 --> 00:00:00.320]
|
||||
[00:00:00.320 --> 00:00:00.370] And
|
||||
[00:00:00.370 --> 00:00:00.690] so
|
||||
[00:00:00.690 --> 00:00:00.850] my
|
||||
[00:00:00.850 --> 00:00:01.590] fellow
|
||||
[00:00:01.590 --> 00:00:02.850] Americans
|
||||
[00:00:02.850 --> 00:00:03.300] ,
|
||||
[00:00:03.300 --> 00:00:04.140] ask
|
||||
[00:00:04.140 --> 00:00:04.990] not
|
||||
[00:00:04.990 --> 00:00:05.410] what
|
||||
[00:00:05.410 --> 00:00:05.660] your
|
||||
[00:00:05.660 --> 00:00:06.260] country
|
||||
[00:00:06.260 --> 00:00:06.600] can
|
||||
[00:00:06.600 --> 00:00:06.840] do
|
||||
[00:00:06.840 --> 00:00:07.010] for
|
||||
[00:00:07.010 --> 00:00:08.170] you
|
||||
[00:00:08.170 --> 00:00:08.190] ,
|
||||
[00:00:08.190 --> 00:00:08.430] ask
|
||||
[00:00:08.430 --> 00:00:08.910] what
|
||||
[00:00:08.910 --> 00:00:09.040] you
|
||||
[00:00:09.040 --> 00:00:09.320] can
|
||||
[00:00:09.320 --> 00:00:09.440] do
|
||||
[00:00:09.440 --> 00:00:09.760] for
|
||||
[00:00:09.760 --> 00:00:10.020] your
|
||||
[00:00:10.020 --> 00:00:10.510] country
|
||||
[00:00:10.510 --> 00:00:11.000] .
|
||||
```
|
||||
|
||||
## Karaoke-style movie generation (experimental)
|
||||
|
||||
The [main](examples/main) example provides support for output of karaoke-style movies, where the
|
||||
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
|
||||
This requires to have `ffmpeg` installed.
|
||||
|
||||
Here are a few *"typical"* examples:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
|
||||
source ./samples/jfk.wav.wts
|
||||
ffplay ./samples/jfk.wav.mp4
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b1c6-323ac4db5b2c.mp4
|
||||
|
||||
---
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
|
||||
source ./samples/mm0.wav.wts
|
||||
ffplay ./samples/mm0.wav.mp4
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-95f9-4227de3570aa.mp4
|
||||
|
||||
---
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
|
||||
source ./samples/gb0.wav.wts
|
||||
ffplay ./samples/gb0.wav.mp4
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a0cd-f28a317987ba.mp4
|
||||
|
||||
---
|
||||
|
||||
## Benchmarks
|
||||
|
||||
In order to have an objective comparison of the performance of the inference across different system configurations,
|
||||
use the [bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
|
||||
took to execute it. The results are summarized in the following Github issue:
|
||||
|
||||
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
|
||||
|
||||
## ggml format
|
||||
|
||||
@ -248,6 +439,46 @@ The original models are converted to a custom binary format. This allows to pack
|
||||
- vocabulary
|
||||
- weights
|
||||
|
||||
You can download the converted models using the [download-ggml-model.sh](download-ggml-model.sh) script.
|
||||
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
|
||||
or manually from here:
|
||||
|
||||
For more details, see the conversion script [convert-pt-to-ggml.py](convert-pt-to-ggml.py) or the README in [models](models).
|
||||
- https://huggingface.co/datasets/ggerganov/whisper.cpp
|
||||
- https://ggml.ggerganov.com
|
||||
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
|
||||
in [models](models).
|
||||
|
||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
|
||||
## Examples
|
||||
|
||||
There are various examples of using the library for different projects in the [examples](examples) folder.
|
||||
Some of the examples are even ported to run in the browser using WebAssembly. Check them out!
|
||||
|
||||
| Example | Web | Description |
|
||||
| --- | --- | --- |
|
||||
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
|
||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||
|
||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
||||
|
||||
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
|
||||
You can use the [Show and tell](https://github.com/ggerganov/whisper.cpp/discussions/categories/show-and-tell) category
|
||||
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
|
||||
[Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126) discussion.
|
||||
|
19
bindings/CMakeLists.txt
Normal file
19
bindings/CMakeLists.txt
Normal file
@ -0,0 +1,19 @@
|
||||
if (EMSCRIPTEN)
|
||||
add_subdirectory(javascript)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/javascript/publish.log
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/whisper.js
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/libwhisper.worker.js
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/package.json
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/javascript
|
||||
COMMAND npm publish
|
||||
COMMAND touch publish.log
|
||||
COMMENT "Publishing npm module v${PROJECT_VERSION}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(publish-npm
|
||||
DEPENDS javascript/publish.log
|
||||
)
|
||||
endif()
|
2
bindings/go/.gitignore
vendored
Normal file
2
bindings/go/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
build
|
||||
models
|
21
bindings/go/LICENSE
Normal file
21
bindings/go/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 David Thorpe
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
38
bindings/go/Makefile
Normal file
38
bindings/go/Makefile
Normal file
@ -0,0 +1,38 @@
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
@install -d ${BUILD_DIR}
|
||||
@echo Mkdir ${MODELS_DIR}
|
||||
@install -d ${MODELS_DIR}
|
||||
|
||||
modtidy:
|
||||
@go mod tidy
|
||||
|
||||
clean:
|
||||
@echo Clean
|
||||
@rm -fr $(BUILD_DIR)
|
||||
@go clean
|
100
bindings/go/README.md
Normal file
100
bindings/go/README.md
Normal file
@ -0,0 +1,100 @@
|
||||
# Go bindings for Whisper
|
||||
|
||||
This package provides Go bindings for whisper.cpp. They have been tested on:
|
||||
|
||||
* Darwin (OS X) 12.6 on x64_64
|
||||
* Debian Linux on arm64
|
||||
* Fedora Linux on x86_64
|
||||
|
||||
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||
is as follows:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var modelpath string // Path to the model
|
||||
var samples []float32 // Samples to process
|
||||
|
||||
// Load the model
|
||||
model, err := whisper.New(modelpath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := context.Process(samples, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Building & Testing
|
||||
|
||||
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp/bindings/go
|
||||
make test
|
||||
```
|
||||
|
||||
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
||||
|
||||
```bash
|
||||
make examples
|
||||
```
|
||||
|
||||
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-model-download -out models
|
||||
```
|
||||
|
||||
And you can then test a model against samples with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
## Using the bindings
|
||||
|
||||
To use the bindings in your own software,
|
||||
|
||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
||||
|
||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
||||
|
||||
The API Documentation:
|
||||
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
||||
|
||||
Getting help:
|
||||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
5
bindings/go/doc.go
Normal file
5
bindings/go/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
/*
|
||||
github.com/ggerganov/whisper.cpp/bindings/go
|
||||
provides a speech-to-text service bindings for the Go programming language.
|
||||
*/
|
||||
package whisper
|
30
bindings/go/examples/go-model-download/context.go
Normal file
30
bindings/go/examples/go-model-download/context.go
Normal file
@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
// ContextForSignal returns a context object which is cancelled when a signal
|
||||
// is received. It returns nil if no signal parameter is provided
|
||||
func ContextForSignal(signals ...os.Signal) context.Context {
|
||||
if len(signals) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := make(chan os.Signal)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Send message on channel when signal received
|
||||
signal.Notify(ch, signals...)
|
||||
|
||||
// When any signal received, call cancel
|
||||
go func() {
|
||||
<-ch
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Return success
|
||||
return ctx
|
||||
}
|
208
bindings/go/examples/go-model-download/main.go
Normal file
208
bindings/go/examples/go-model-download/main.go
Normal file
@ -0,0 +1,208 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
// The output folder. When not set, use current working directory.
|
||||
flagOut = flag.String("out", "", "Output folder")
|
||||
|
||||
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
||||
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
||||
|
||||
// Quiet parameter - will not print progress if set
|
||||
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MAIN
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
name := filepath.Base(flag.CommandLine.Name())
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Get output path
|
||||
out, err := GetOut()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Create context which quits on SIGINT or SIGQUIT
|
||||
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
||||
|
||||
// Progress filehandle
|
||||
progress := os.Stdout
|
||||
if *flagQuiet {
|
||||
progress, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
defer progress.Close()
|
||||
}
|
||||
|
||||
// Download models - exit on error or interrupt
|
||||
for _, model := range GetModels() {
|
||||
url, err := URLForModel(model)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
continue
|
||||
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
||||
continue
|
||||
} else if err == context.Canceled {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "\nInterrupted")
|
||||
break
|
||||
} else if err == context.DeadlineExceeded {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "Timeout downloading model")
|
||||
continue
|
||||
} else {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// GetOut returns the path to the output directory
|
||||
func GetOut() (string, error) {
|
||||
if *flagOut == "" {
|
||||
return os.Getwd()
|
||||
}
|
||||
if info, err := os.Stat(*flagOut); err != nil {
|
||||
return "", err
|
||||
} else if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory: %s", info.Name())
|
||||
} else {
|
||||
return *flagOut, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns the list of models to download
|
||||
func GetModels() []string {
|
||||
if flag.NArg() == 0 {
|
||||
return modelNames
|
||||
} else {
|
||||
return flag.Args()
|
||||
}
|
||||
}
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
if filepath.Ext(model) != srcExt {
|
||||
model += srcExt
|
||||
}
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = filepath.Join(url.Path, model)
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Download downloads the model from the given URL to the given output directory
|
||||
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
||||
// Create HTTP client
|
||||
client := http.Client{
|
||||
Timeout: *flagTimeout,
|
||||
}
|
||||
|
||||
// Initiate the download
|
||||
req, err := http.NewRequest("GET", model, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
||||
}
|
||||
|
||||
// If output file exists and is the same size as the model, skip
|
||||
path := filepath.Join(out, filepath.Base(model))
|
||||
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
||||
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Create file
|
||||
w, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Report
|
||||
fmt.Fprintln(p, "Downloading", model, "to", out)
|
||||
|
||||
// Progressively download the model
|
||||
data := make([]byte, bufSize)
|
||||
count, pct := int64(0), int64(0)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Cancelled, return error
|
||||
return path, ctx.Err()
|
||||
case <-ticker.C:
|
||||
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
||||
default:
|
||||
// Read body
|
||||
n, err := resp.Body.Read(data)
|
||||
if err != nil {
|
||||
DownloadReport(p, pct, count, resp.ContentLength)
|
||||
return path, err
|
||||
} else if m, err := w.Write(data[:n]); err != nil {
|
||||
return path, err
|
||||
} else {
|
||||
count += int64(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Report periodically reports the download progress when percentage changes
|
||||
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
|
||||
pct_ := count * 100 / total
|
||||
if pct_ > pct {
|
||||
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
|
||||
}
|
||||
return pct_
|
||||
}
|
22
bindings/go/examples/go-whisper/color.go
Normal file
22
bindings/go/examples/go-whisper/color.go
Normal file
@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
|
||||
RGBSuffix = "m"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Colorize text with RGB values, from 0 to 23
|
||||
func Colorize(text string, v int) string {
|
||||
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
|
||||
// Grayscale colors are in the range 232-255
|
||||
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
|
||||
}
|
156
bindings/go/examples/go-whisper/flags.go
Normal file
156
bindings/go/examples/go-whisper/flags.go
Normal file
@ -0,0 +1,156 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type Flags struct {
|
||||
*flag.FlagSet
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewFlags(name string, args []string) (*Flags, error) {
|
||||
flags := &Flags{
|
||||
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||
}
|
||||
|
||||
// Register the command line arguments
|
||||
registerFlags(flags)
|
||||
|
||||
// Parse command line
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (flags *Flags) GetModel() string {
|
||||
return flags.Lookup("model").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) GetLanguage() string {
|
||||
return flags.Lookup("language").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTranslate() bool {
|
||||
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOffset() time.Duration {
|
||||
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetDuration() time.Duration {
|
||||
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetThreads() uint {
|
||||
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOut() string {
|
||||
return strings.ToLower(flags.Lookup("out").Value.String())
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsColorize() bool {
|
||||
return flags.Lookup("colorize").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxLen() uint {
|
||||
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxTokens() uint {
|
||||
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetWordThreshold() float32 {
|
||||
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
|
||||
}
|
||||
|
||||
func (flags *Flags) SetParams(context whisper.Context) error {
|
||||
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
|
||||
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if flags.IsTranslate() && context.IsMultilingual() {
|
||||
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
|
||||
context.SetTranslate(true)
|
||||
}
|
||||
if offset := flags.GetOffset(); offset != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
|
||||
context.SetOffset(offset)
|
||||
}
|
||||
if duration := flags.GetDuration(); duration != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
||||
context.SetDuration(duration)
|
||||
}
|
||||
if flags.IsSpeedup() {
|
||||
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if threads := flags.GetThreads(); threads != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
||||
context.SetThreads(threads)
|
||||
}
|
||||
if max_len := flags.GetMaxLen(); max_len != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
|
||||
context.SetMaxSegmentLength(max_len)
|
||||
}
|
||||
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
|
||||
context.SetMaxTokensPerSegment(max_tokens)
|
||||
}
|
||||
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
|
||||
context.SetTokenThreshold(word_threshold)
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func registerFlags(flag *Flags) {
|
||||
flag.String("model", "", "Path to the model file")
|
||||
flag.String("language", "", "Spoken language")
|
||||
flag.Bool("translate", false, "Translate from source language to english")
|
||||
flag.Duration("offset", 0, "Time offset")
|
||||
flag.Duration("duration", 0, "Duration of audio to process")
|
||||
flag.Uint("threads", 0, "Number of threads to use")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
||||
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
||||
flag.Float64("word-thold", 0, "Maximum segment score")
|
||||
flag.Bool("tokens", false, "Display tokens")
|
||||
flag.Bool("colorize", false, "Colorize tokens")
|
||||
flag.String("out", "", "Output format (srt, none or leave as empty string)")
|
||||
}
|
43
bindings/go/examples/go-whisper/main.go
Normal file
43
bindings/go/examples/go-whisper/main.go
Normal file
@ -0,0 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
||||
if err == flag.ErrHelp {
|
||||
os.Exit(0)
|
||||
} else if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
} else if flags.GetModel() == "" {
|
||||
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
||||
os.Exit(1)
|
||||
} else if flags.NArg() == 0 {
|
||||
fmt.Fprintln(os.Stderr, "No input files specified")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(flags.GetModel())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process files
|
||||
for _, filename := range flags.Args() {
|
||||
if err := Process(model, filename, flags); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
127
bindings/go/examples/go-whisper/process.go
Normal file
127
bindings/go/examples/go-whisper/process.go
Normal file
@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// Package imports
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
var data []float32
|
||||
|
||||
// Create processing context
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
if err := flags.SetParams(context); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open the file
|
||||
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||
return err
|
||||
} else if dec.SampleRate != whisper.SampleRate {
|
||||
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||
} else if dec.NumChans != 1 {
|
||||
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||
} else {
|
||||
data = buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// Segment callback when -tokens is specified
|
||||
var cb whisper.SegmentCallback
|
||||
if flags.IsTokens() {
|
||||
cb = func(segment whisper.Segment) {
|
||||
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
for _, token := range segment.Tokens {
|
||||
if flags.IsColorize() && context.IsText(token) {
|
||||
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
|
||||
} else {
|
||||
fmt.Fprint(flags.Output(), token.Text, " ")
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
switch {
|
||||
case flags.GetOut() == "srt":
|
||||
return OutputSRT(os.Stdout, context)
|
||||
case flags.GetOut() == "none":
|
||||
return nil
|
||||
default:
|
||||
return Output(os.Stdout, context, flags.IsColorize())
|
||||
}
|
||||
}
|
||||
|
||||
// Output text as SRT file
|
||||
func OutputSRT(w io.Writer, context whisper.Context) error {
|
||||
n := 1
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(w, n)
|
||||
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
|
||||
fmt.Fprintln(w, segment.Text)
|
||||
fmt.Fprintln(w, "")
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Output text to terminal
|
||||
func Output(w io.Writer, context whisper.Context, colorize bool) error {
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
if colorize {
|
||||
for _, token := range segment.Tokens {
|
||||
if !context.IsText(token) {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
|
||||
}
|
||||
fmt.Fprint(w, "\n")
|
||||
} else {
|
||||
fmt.Fprintln(w, " ", segment.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return srtTimestamp
|
||||
func srtTimestamp(t time.Duration) string {
|
||||
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
|
||||
}
|
16
bindings/go/go.mod
Normal file
16
bindings/go/go.mod
Normal file
@ -0,0 +1,16 @@
|
||||
module github.com/ggerganov/whisper.cpp/bindings/go
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
23
bindings/go/go.sum
Normal file
23
bindings/go/go.sum
Normal file
@ -0,0 +1,23 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
156
bindings/go/params.go
Normal file
156
bindings/go/params.go
Normal file
@ -0,0 +1,156 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#include <whisper.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (p *Params) SetTranslate(v bool) {
|
||||
p.translate = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetNoContext(v bool) {
|
||||
p.no_context = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSingleSegment(v bool) {
|
||||
p.single_segment = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintSpecial(v bool) {
|
||||
p.print_special = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintProgress(v bool) {
|
||||
p.print_progress = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintRealtime(v bool) {
|
||||
p.print_realtime = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintTimestamps(v bool) {
|
||||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSpeedup(v bool) {
|
||||
p.speed_up = toBool(v)
|
||||
}
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
str := C.whisper_lang_str(C.int(lang))
|
||||
if str == nil {
|
||||
return ErrInvalidLanguage
|
||||
} else {
|
||||
p.language = str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get language id
|
||||
func (p *Params) Language() int {
|
||||
if p.language == nil {
|
||||
return -1
|
||||
}
|
||||
return int(C.whisper_lang_id(p.language))
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (p *Params) SetThreads(threads int) {
|
||||
p.n_threads = C.int(threads)
|
||||
}
|
||||
|
||||
// Set start offset in ms
|
||||
func (p *Params) SetOffset(offset_ms int) {
|
||||
p.offset_ms = C.int(offset_ms)
|
||||
}
|
||||
|
||||
// Set audio duration to process in ms
|
||||
func (p *Params) SetDuration(duration_ms int) {
|
||||
p.duration_ms = C.int(duration_ms)
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (p *Params) SetTokenThreshold(t float32) {
|
||||
p.thold_pt = C.float(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (p *Params) SetTokenSumThreshold(t float32) {
|
||||
p.thold_ptsum = C.float(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (p *Params) SetMaxSegmentLength(n int) {
|
||||
p.max_len = C.int(n)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (p *Params) SetMaxTokensPerSegment(n int) {
|
||||
p.max_tokens = C.int(n)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toBool(v bool) C.bool {
|
||||
if v {
|
||||
return C.bool(true)
|
||||
}
|
||||
return C.bool(false)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (p *Params) String() string {
|
||||
str := "<whisper.params"
|
||||
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
||||
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
||||
if p.language != nil {
|
||||
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
||||
}
|
||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
if p.no_context {
|
||||
str += " no_context"
|
||||
}
|
||||
if p.single_segment {
|
||||
str += " single_segment"
|
||||
}
|
||||
if p.print_special {
|
||||
str += " print_special"
|
||||
}
|
||||
if p.print_progress {
|
||||
str += " print_progress"
|
||||
}
|
||||
if p.print_realtime {
|
||||
str += " print_realtime"
|
||||
}
|
||||
if p.print_timestamps {
|
||||
str += " print_timestamps"
|
||||
}
|
||||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.speed_up {
|
||||
str += " speed_up"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
28
bindings/go/pkg/whisper/consts.go
Normal file
28
bindings/go/pkg/whisper/consts.go
Normal file
@ -0,0 +1,28 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrModelNotMultilingual = errors.New("model is not multilingual")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
// SampleRate is the sample rate of the audio data.
|
||||
const SampleRate = whisper.SampleRate
|
||||
|
||||
// SampleBits is the number of bytes per sample.
|
||||
const SampleBits = whisper.SampleBits
|
251
bindings/go/pkg/whisper/context.go
Normal file
251
bindings/go/pkg/whisper/context.go
Normal file
@ -0,0 +1,251 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type context struct {
|
||||
n int
|
||||
model *model
|
||||
params whisper.Params
|
||||
}
|
||||
|
||||
// Make sure context adheres to the interface
|
||||
var _ Context = (*context)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func newContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
|
||||
// Return success
|
||||
return context, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Set the language to use for speech recognition.
|
||||
func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *context) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
// Set translate flag
|
||||
func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (context *context) SetThreads(v uint) {
|
||||
context.params.SetThreads(int(v))
|
||||
}
|
||||
|
||||
// Set time offset
|
||||
func (context *context) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set duration of audio to process
|
||||
func (context *context) SetDuration(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (context *context) SetTokenThreshold(t float32) {
|
||||
context.params.SetTokenThreshold(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (context *context) SetTokenSumThreshold(t float32) {
|
||||
context.params.SetTokenSumThreshold(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if cb != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
// We don't do parallel processing at the moment
|
||||
processors := 0
|
||||
if processors > 1 {
|
||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the next segment of tokens
|
||||
func (context *context) NextSegment() (Segment, error) {
|
||||
if context.model.ctx == nil {
|
||||
return Segment{}, ErrInternalAppError
|
||||
}
|
||||
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
// Populate result
|
||||
result := toSegment(context.model.ctx, context.n)
|
||||
|
||||
// Increment the cursor
|
||||
context.n++
|
||||
|
||||
// Return success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Test for text tokens
|
||||
func (context *context) IsText(t Token) bool {
|
||||
switch {
|
||||
case context.IsBEG(t):
|
||||
return false
|
||||
case context.IsSOT(t):
|
||||
return false
|
||||
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
||||
return false
|
||||
case context.IsPREV(t):
|
||||
return false
|
||||
case context.IsSOLM(t):
|
||||
return false
|
||||
case context.IsNOT(t):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Test for "begin" token
|
||||
func (context *context) IsBEG(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
||||
}
|
||||
|
||||
// Test for "start of transcription" token
|
||||
func (context *context) IsSOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
||||
}
|
||||
|
||||
// Test for "end of transcription" token
|
||||
func (context *context) IsEOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
||||
}
|
||||
|
||||
// Test for "start of prev" token
|
||||
func (context *context) IsPREV(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
||||
}
|
||||
|
||||
// Test for "start of lm" token
|
||||
func (context *context) IsSOLM(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
||||
}
|
||||
|
||||
// Test for "No timestamps" token
|
||||
func (context *context) IsNOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
||||
}
|
||||
|
||||
// Test for token associated with a specific language
|
||||
func (context *context) IsLANG(t Token, lang string) bool {
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokens(ctx, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
55
bindings/go/pkg/whisper/context_test.go
Normal file
55
bindings/go/pkg/whisper/context_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "../../models/ggml-tiny.bin"
|
||||
SamplePath = "../../samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
assert.NoError(model.Close())
|
||||
|
||||
t.Log("languages=", model.Languages())
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
// Get context for decoding
|
||||
ctx, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
assert.NotNil(ctx)
|
||||
|
||||
}
|
4
bindings/go/pkg/whisper/doc.go
Normal file
4
bindings/go/pkg/whisper/doc.go
Normal file
@ -0,0 +1,4 @@
|
||||
/*
|
||||
This is the higher-level speech-to-text whisper.cpp API for go
|
||||
*/
|
||||
package whisper
|
85
bindings/go/pkg/whisper/interface.go
Normal file
85
bindings/go/pkg/whisper/interface.go
Normal file
@ -0,0 +1,85 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
// SegmentCallback is the callback function for processing segments in real
|
||||
// time. It is called during the Process function
|
||||
type SegmentCallback func(Segment)
|
||||
|
||||
// Model is the interface to a whisper model. Create a new model with the
|
||||
// function whisper.New(string)
|
||||
type Model interface {
|
||||
io.Closer
|
||||
|
||||
// Return a new speech-to-text context.
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return true if the model is multilingual.
|
||||
IsMultilingual() bool
|
||||
|
||||
// Return all languages supported.
|
||||
Languages() []string
|
||||
}
|
||||
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
// callback function during processing.
|
||||
Process([]float32, SegmentCallback) error
|
||||
|
||||
// After process is called, return segments until the end of the stream
|
||||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
|
||||
IsBEG(Token) bool // Test for "begin" token
|
||||
IsSOT(Token) bool // Test for "start of transcription" token
|
||||
IsEOT(Token) bool // Test for "end of transcription" token
|
||||
IsPREV(Token) bool // Test for "start of prev" token
|
||||
IsSOLM(Token) bool // Test for "start of lm" token
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
type Segment struct {
|
||||
// Segment Number
|
||||
Num int
|
||||
|
||||
// Time beginning and end timestamps for the segment.
|
||||
Start, End time.Duration
|
||||
|
||||
// The text of the segment.
|
||||
Text string
|
||||
|
||||
// The tokens of the segment.
|
||||
Tokens []Token
|
||||
}
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
}
|
100
bindings/go/pkg/whisper/model.go
Normal file
100
bindings/go/pkg/whisper/model.go
Normal file
@ -0,0 +1,100 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type model struct {
|
||||
path string
|
||||
ctx *whisper.Context
|
||||
}
|
||||
|
||||
// Make sure model adheres to the interface
|
||||
var _ Model = (*model)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
if model.ctx != nil {
|
||||
model.ctx.Whisper_free()
|
||||
}
|
||||
|
||||
// Release resources
|
||||
model.ctx = nil
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (model *model) String() string {
|
||||
str := "<whisper.model"
|
||||
if model.ctx != nil {
|
||||
str += fmt.Sprintf(" model=%q", model.path)
|
||||
}
|
||||
return str + ">"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return true if model is multilingual (language and translation options are supported)
|
||||
func (model *model) IsMultilingual() bool {
|
||||
return model.ctx.Whisper_is_multilingual() != 0
|
||||
}
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (model *model) NewContext() (Context, error) {
|
||||
if model.ctx == nil {
|
||||
return nil, ErrInternalAppError
|
||||
}
|
||||
|
||||
// Create new context
|
||||
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
|
||||
// Return new context
|
||||
return newContext(model, params)
|
||||
}
|
BIN
bindings/go/samples/jfk.wav
Normal file
BIN
bindings/go/samples/jfk.wav
Normal file
Binary file not shown.
419
bindings/go/whisper.go
Normal file
419
bindings/go/whisper.go
Normal file
@ -0,0 +1,419 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
extern bool callEncoderBegin(void* user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callNewSegment(user_data, n_new);
|
||||
}
|
||||
}
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
return callEncoderBegin(user_data);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get default parameters and set callbacks
|
||||
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||
params.new_segment_callback = whisper_new_segment_cb;
|
||||
params.new_segment_callback_user_data = (void*)(ctx);
|
||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type (
|
||||
Context C.struct_whisper_context
|
||||
Token C.whisper_token
|
||||
TokenData C.struct_whisper_token_data
|
||||
SamplingStrategy C.enum_whisper_sampling_strategy
|
||||
Params C.struct_whisper_full_params
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GLOBALS
|
||||
|
||||
const (
|
||||
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
||||
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
||||
)
|
||||
|
||||
const (
|
||||
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
||||
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
||||
NumFFT = C.WHISPER_N_FFT
|
||||
NumMEL = C.WHISPER_N_MEL
|
||||
HopLength = C.WHISPER_HOP_LENGTH
|
||||
ChunkSize = C.WHISPER_CHUNK_SIZE
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
||||
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
||||
ErrConversionFailed = errors.New("whisper_convert failed")
|
||||
ErrInvalidLanguage = errors.New("invalid language")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Allocates all memory needed for the model and loads the model from the given file.
|
||||
// Returns NULL on failure.
|
||||
func Whisper_init(path string) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the provided whisper context.
|
||||
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
||||
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
||||
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
// Make sure to call whisper_encode() first.
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
// n_past is the number of tokens to use from previous decoder calls.
|
||||
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
||||
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// whisper_sample_best() returns the token with the highest probability
|
||||
func (ctx *Context) Whisper_sample_best() TokenData {
|
||||
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// whisper_sample_timestamp() returns the most probable timestamp token
|
||||
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
|
||||
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
|
||||
}
|
||||
|
||||
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success
|
||||
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
cText := C.CString(text)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
||||
return int(n), nil
|
||||
} else {
|
||||
return 0, ErrTokenizerFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
//
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
func Whisper_lang_max_id() int {
|
||||
return int(C.whisper_lang_max_id())
|
||||
}
|
||||
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
||||
// returns empty string if not found
|
||||
func Whisper_lang_str(id int) string {
|
||||
return C.GoString(C.whisper_lang_str(C.int(id)))
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
||||
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||
return nil, ErrAutoDetectFailed
|
||||
} else {
|
||||
return probs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len() int {
|
||||
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_vocab() int {
|
||||
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// The probabilities for the next token
|
||||
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
||||
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
||||
//}
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
||||
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_eot() Token {
|
||||
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_sot() Token {
|
||||
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_prev() Token {
|
||||
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_solm() Token {
|
||||
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_not() Token {
|
||||
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_beg() Token {
|
||||
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
||||
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_translate() Token {
|
||||
return Token(C.whisper_token_translate())
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_transcribe() Token {
|
||||
return Token(C.whisper_token_transcribe())
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_print_timings() {
|
||||
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_reset_timings() {
|
||||
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Print system information
|
||||
func Whisper_print_system_info() string {
|
||||
return C.GoString(C.whisper_print_system_info())
|
||||
}
|
||||
|
||||
// Return default parameters for a strategy
|
||||
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
||||
// Get default parameters
|
||||
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
||||
}
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
|
||||
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Number of generated text segments.
|
||||
// A segment can be a few words, a sentence, or even a paragraph.
|
||||
func (ctx *Context) Whisper_full_n_segments() int {
|
||||
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the text of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get number of tokens in the specified segment.
|
||||
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the token text of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the token of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CALLBACKS
|
||||
|
||||
var (
|
||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||
)
|
||||
|
||||
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
if fn == nil {
|
||||
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
//export callNewSegment
|
||||
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
if fn, ok := cbNewSegment[user_data]; ok {
|
||||
fn(int(new))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
if fn() {
|
||||
return C.bool(true)
|
||||
} else {
|
||||
return C.bool(false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
113
bindings/go/whisper_test.go
Normal file
113
bindings/go/whisper_test.go
Normal file
@ -0,0 +1,113 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
wav "github.com/go-audio/wav"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "models/ggml-small.en.bin"
|
||||
SamplePath = "samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
ctx.Whisper_free()
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Run whisper
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
err = ctx.Whisper_full(params, data, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
assert.GreaterOrEqual(num_segments, 1)
|
||||
for i := 0; i < num_segments; i++ {
|
||||
str := ctx.Whisper_full_get_segment_text(i)
|
||||
assert.NotEmpty(str)
|
||||
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
||||
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
||||
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_002(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
assert.NotEmpty(str)
|
||||
t.Log(str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_003(t *testing.T) {
|
||||
threads := runtime.NumCPU()
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Make the model
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
// Get MEL
|
||||
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
||||
|
||||
// Get Languages
|
||||
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
||||
assert.NoError(err)
|
||||
for i, p := range languages {
|
||||
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||
}
|
||||
}
|
1
bindings/ios
Submodule
1
bindings/ios
Submodule
Submodule bindings/ios added at f6334b026f
1
bindings/javascript/.gitignore
vendored
Normal file
1
bindings/javascript/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
publish.log
|
41
bindings/javascript/CMakeLists.txt
Normal file
41
bindings/javascript/CMakeLists.txt
Normal file
@ -0,0 +1,41 @@
|
||||
set(TARGET libwhisper)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside whisper.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libwhisper.js
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/whisper.js
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libwhisper.worker.js
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/libwhisper.worker.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s MODULARIZE=1 \
|
||||
-s EXPORT_NAME=\"'whisper_factory'\" \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s ALLOW_MEMORY_GROWTH=1 \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
78
bindings/javascript/README.md
Normal file
78
bindings/javascript/README.md
Normal file
@ -0,0 +1,78 @@
|
||||
# whisper.cpp
|
||||
|
||||
Node.js package for Whisper speech recognition
|
||||
|
||||
Package: https://www.npmjs.com/package/whisper.cpp
|
||||
|
||||
## Details
|
||||
|
||||
The performance is comparable to when running `whisper.cpp` in the browser via WASM.
|
||||
|
||||
The API is currently very rudimentary: [bindings/javascript/emscripten.cpp](/bindings/javascript/emscripten.cpp)
|
||||
|
||||
For sample usage check [tests/test-whisper.js](/tests/test-whisper.js)
|
||||
|
||||
## Package building + test
|
||||
|
||||
```bash
|
||||
# load emscripten
|
||||
source /path/to/emsdk/emsdk_env.sh
|
||||
|
||||
# clone repo
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
|
||||
# grab base.en model
|
||||
./models/download-ggml-model.sh base.en
|
||||
|
||||
# prepare PCM sample for testing
|
||||
ffmpeg -i samples/jfk.wav -f f32le -acodec pcm_f32le samples/jfk.pcmf32
|
||||
|
||||
# build
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake .. && make -j
|
||||
|
||||
# run test
|
||||
node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
|
||||
|
||||
# publish npm package
|
||||
make publish-npm
|
||||
```
|
||||
|
||||
## Sample run
|
||||
|
||||
```java
|
||||
$ node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
|
||||
|
||||
whisper_model_load: loading model from 'whisper.bin'
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 512
|
||||
whisper_model_load: n_audio_head = 8
|
||||
whisper_model_load: n_audio_layer = 6
|
||||
whisper_model_load: n_text_ctx = 448
|
||||
whisper_model_load: n_text_state = 512
|
||||
whisper_model_load: n_text_head = 8
|
||||
whisper_model_load: n_text_layer = 6
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 2
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: mem_required = 506.00 MB
|
||||
whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 1 | BLAS = 0 |
|
||||
|
||||
operator(): processing 176000 samples, 11.0 sec, 8 threads, 1 processors, lang = en, task = transcribe ...
|
||||
|
||||
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
whisper_print_timings: load time = 162.37 ms
|
||||
whisper_print_timings: mel time = 183.70 ms
|
||||
whisper_print_timings: sample time = 4.27 ms
|
||||
whisper_print_timings: encode time = 8582.63 ms / 1430.44 ms per layer
|
||||
whisper_print_timings: decode time = 436.16 ms / 72.69 ms per layer
|
||||
whisper_print_timings: total time = 9370.90 ms
|
||||
```
|
93
bindings/javascript/emscripten.cpp
Normal file
93
bindings/javascript/emscripten.cpp
Normal file
@ -0,0 +1,93 @@
|
||||
//
|
||||
// This is the Javascript API of whisper.cpp
|
||||
//
|
||||
// Very crude at the moment.
|
||||
// Feel free to contribute and make this better!
|
||||
//
|
||||
// See the tests/test-whisper.js for sample usage
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
struct whisper_context * g_context;
|
||||
|
||||
EMSCRIPTEN_BINDINGS(whisper) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
if (g_context == nullptr) {
|
||||
g_context = whisper_init_from_file(path_model.c_str());
|
||||
if (g_context != nullptr) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([]() {
|
||||
if (g_context) {
|
||||
whisper_free(g_context);
|
||||
g_context = nullptr;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) {
|
||||
if (g_context == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct whisper_full_params params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = translate;
|
||||
params.language = whisper_is_multilingual(g_context) ? lang.c_str() : "en";
|
||||
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
|
||||
params.offset_ms = 0;
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
const int n = audio["length"].as<int>();
|
||||
|
||||
emscripten::val heap = emscripten::val::module_property("HEAPU8");
|
||||
emscripten::val memory = heap["buffer"];
|
||||
|
||||
pcmf32.resize(n);
|
||||
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
|
||||
// print system information
|
||||
{
|
||||
printf("\n");
|
||||
printf("system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
|
||||
printf("\n");
|
||||
printf("%s: processing %d samples, %.1f sec, %d threads, %d processors, lang = %s, task = %s ...\n",
|
||||
__func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, 1,
|
||||
params.language,
|
||||
params.translate ? "translate" : "transcribe");
|
||||
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// run whisper
|
||||
{
|
||||
whisper_reset_timings(g_context);
|
||||
whisper_full(g_context, params, pcmf32.data(), pcmf32.size());
|
||||
whisper_print_timings(g_context);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}));
|
||||
}
|
1
bindings/javascript/libwhisper.worker.js
Normal file
1
bindings/javascript/libwhisper.worker.js
Normal file
@ -0,0 +1 @@
|
||||
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f)},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}var initializedJS=false;var pendingNotifiedProxyingQueues=[];function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};self.onunhandledrejection=e=>{throw e.reason??e};self.onmessage=e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=function(){postMessage({cmd:"callHandler",handler:handler,args:[...arguments]})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();pendingNotifiedProxyingQueues.forEach(queue=>{Module["executeNotifiedProxyingQueue"](queue)});pendingNotifiedProxyingQueues=[];initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processProxyingQueue"){if(initializedJS){Module["executeNotifiedProxyingQueue"](e.data.queue)}else{pendingNotifiedProxyingQueues.push(e.data.queue)}}else if(e.data.cmd){err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}};
|
26
bindings/javascript/package-tmpl.json
Normal file
26
bindings/javascript/package-tmpl.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "@PROJECT_VERSION@",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
"test": "echo \"todo: add tests\" && exit 0"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"keywords": [
|
||||
"openai",
|
||||
"whisper",
|
||||
"speech-to-text",
|
||||
"speech-recognition",
|
||||
"transformer"
|
||||
],
|
||||
"author": "Georgi Gerganov",
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ggerganov/whisper.cpp/issues"
|
||||
},
|
||||
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
|
||||
}
|
26
bindings/javascript/package.json
Normal file
26
bindings/javascript/package.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.1.0",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
"test": "echo \"todo: add tests\" && exit 0"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"keywords": [
|
||||
"openai",
|
||||
"whisper",
|
||||
"speech-to-text",
|
||||
"speech-recognition",
|
||||
"transformer"
|
||||
],
|
||||
"author": "Georgi Gerganov",
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ggerganov/whisper.cpp/issues"
|
||||
},
|
||||
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
|
||||
}
|
21
bindings/javascript/whisper.js
Normal file
21
bindings/javascript/whisper.js
Normal file
File diff suppressed because one or more lines are too long
54
cmake/BuildTypes.cmake
Normal file
54
cmake/BuildTypes.cmake
Normal file
@ -0,0 +1,54 @@
|
||||
# Add new build types
|
||||
|
||||
# ReleaseGG - Release with enabled asserts
|
||||
|
||||
SET(CMAKE_CXX_FLAGS_RELEASEGG
|
||||
"-O3"
|
||||
CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_C_FLAGS_RELEASEGG
|
||||
"-O3"
|
||||
CACHE STRING "Flags used by the compiler during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG
|
||||
""
|
||||
CACHE STRING "Flags used for linking binaries during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG
|
||||
""
|
||||
CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts."
|
||||
FORCE )
|
||||
MARK_AS_ADVANCED(
|
||||
CMAKE_CXX_FLAGS_RELEASEGG
|
||||
CMAKE_C_FLAGS_RELEASEGG
|
||||
CMAKE_EXE_LINKER_FLAGS_RELEASEGG
|
||||
CMAKE_SHARED_LINKER_FLAGS_RELEASEGG )
|
||||
|
||||
# RelWithDebInfoGG - RelWithDebInfo with enabled asserts
|
||||
|
||||
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
|
||||
"-O2 -g"
|
||||
CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG
|
||||
"-O2 -g"
|
||||
CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
""
|
||||
CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
""
|
||||
CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
MARK_AS_ADVANCED(
|
||||
CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_C_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG )
|
||||
|
||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG")
|
||||
endif()
|
17
cmake/DefaultTargetOptions.cmake
Normal file
17
cmake/DefaultTargetOptions.cmake
Normal file
@ -0,0 +1,17 @@
|
||||
# Set the default compile features and properties for a target.
|
||||
|
||||
if (NOT TARGET)
|
||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET}
|
||||
PRIVATE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
22
cmake/GitVars.cmake
Normal file
22
cmake/GitVars.cmake
Normal file
@ -0,0 +1,22 @@
|
||||
find_package(Git)
|
||||
|
||||
# the commit's SHA1
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_SHA1
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the date of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_DATE
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the subject of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%s
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
33
examples/CMakeLists.txt
Normal file
33
examples/CMakeLists.txt
Normal file
@ -0,0 +1,33 @@
|
||||
# dependencies
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# third-party
|
||||
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# SDL2
|
||||
find_package(SDL2 REQUIRED)
|
||||
|
||||
string(STRIP "${SDL2_LIBRARIES}" SDL2_LIBRARIES)
|
||||
|
||||
message(STATUS "SDL2_INCLUDE_DIRS = ${SDL2_INCLUDE_DIRS}")
|
||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||
endif()
|
||||
|
||||
# examples
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
add_subdirectory(whisper.wasm)
|
||||
add_subdirectory(stream.wasm)
|
||||
add_subdirectory(command.wasm)
|
||||
add_subdirectory(talk.wasm)
|
||||
add_subdirectory(bench.wasm)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(talk)
|
||||
endif()
|
49
examples/bench.wasm/CMakeLists.txt
Normal file
49
examples/bench.wasm/CMakeLists.txt
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# libbench
|
||||
#
|
||||
|
||||
set(TARGET libbench)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside bench.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libbench.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/bench.wasm/bench.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
#
|
||||
# bench.wasm
|
||||
#
|
||||
|
||||
set(TARGET bench.wasm)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)
|
22
examples/bench.wasm/README.md
Normal file
22
examples/bench.wasm/README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# bench.wasm
|
||||
|
||||
Benchmark the performance of whisper.cpp in the browser using WebAssembly
|
||||
|
||||
Link: https://whisper.ggerganov.com/bench/
|
||||
|
||||
Terminal version: [examples/bench](/examples/bench)
|
||||
|
||||
## Build instructions
|
||||
|
||||
```bash
|
||||
# build using Emscripten (v3.1.2)
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/bench.wasm/* /path/to/html/
|
||||
cp bin/libbench.worker.js /path/to/html/
|
||||
```
|
85
examples/bench.wasm/emscripten.cpp
Normal file
85
examples/bench.wasm/emscripten.cpp
Normal file
@ -0,0 +1,85 @@
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
// TODO: get rid of this vector of contexts - bad idea in the first place
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::thread g_worker;
|
||||
|
||||
void bench_main(size_t index) {
|
||||
const int n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// whisper context
|
||||
auto & ctx = g_contexts[index];
|
||||
|
||||
fprintf(stderr, "%s: running benchmark with %d threads - please wait...\n", __func__, n_threads);
|
||||
|
||||
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
|
||||
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "If you wish, you can submit these results here:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "Please include the following information:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " - CPU model\n");
|
||||
fprintf(stderr, " - Operating system\n");
|
||||
fprintf(stderr, " - Browser\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(bench) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
bench_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}));
|
||||
}
|
227
examples/bench.wasm/index-tmpl.html
Normal file
227
examples/bench.wasm/index-tmpl.html
Normal file
@ -0,0 +1,227 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>bench : Benchmark whisper.cpp performance in the browser</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<b>bench : Benchmark whisper.cpp performance in the browser</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use and click the "Bench" button.<br>
|
||||
The results will be displayed in the textarea below.
|
||||
|
||||
<br><br>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="input">
|
||||
<button id="bench" onclick="onBench()" disabled>Bench</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// the bench instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Initialized successfully!');
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
model_whisper = fname;
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
if (model_whisper != null) {
|
||||
document.getElementById('bench').disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function loadFile(event, fname) {
|
||||
var file = event.target.files[0] || null;
|
||||
if (file == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
printTextarea("loadFile: loading model: " + file.name + ", size: " + file.size + " bytes");
|
||||
printTextarea('loadFile: please wait ...');
|
||||
|
||||
var reader = new FileReader();
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
storeFS(fname, buf);
|
||||
}
|
||||
reader.readAsArrayBuffer(file);
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
function onBench() {
|
||||
if (instance) {
|
||||
Module.free(instance);
|
||||
}
|
||||
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
|
||||
document.getElementById('bench').disabled = true;
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="bench.js"></script>
|
||||
</body>
|
||||
</html>
|
6
examples/bench/CMakeLists.txt
Normal file
6
examples/bench/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
set(TARGET bench)
|
||||
add_executable(${TARGET} bench.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
54
examples/bench/README.md
Normal file
54
examples/bench/README.md
Normal file
@ -0,0 +1,54 @@
|
||||
# bench
|
||||
|
||||
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of
|
||||
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
|
||||
of the performance of the model for various setups.
|
||||
|
||||
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
|
||||
|
||||
```bash
|
||||
# build the bench tool
|
||||
$ make bench
|
||||
|
||||
# run it on the small.en model using 4 threads
|
||||
$ ./bench -m ./models/ggml-small.en.bin -t 4
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-small.en.bin'
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 768
|
||||
whisper_model_load: n_audio_head = 12
|
||||
whisper_model_load: n_audio_layer = 12
|
||||
whisper_model_load: n_text_ctx = 448
|
||||
whisper_model_load: n_text_state = 768
|
||||
whisper_model_load: n_text_head = 12
|
||||
whisper_model_load: n_text_layer = 12
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 3
|
||||
whisper_model_load: mem_required = 1048.00 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: ggml ctx size = 533.05 MB
|
||||
whisper_model_load: memory size = 68.48 MB
|
||||
whisper_model_load: model size = 464.44 MB
|
||||
|
||||
whisper_print_timings: load time = 240.82 ms
|
||||
whisper_print_timings: mel time = 0.00 ms
|
||||
whisper_print_timings: sample time = 0.00 ms
|
||||
whisper_print_timings: encode time = 1062.21 ms / 88.52 ms per layer
|
||||
whisper_print_timings: decode time = 0.00 ms / 0.00 ms per layer
|
||||
whisper_print_timings: total time = 1303.04 ms
|
||||
|
||||
system_info: n_threads = 4 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
If you wish, you can submit these results here:
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/issues/89
|
||||
|
||||
Please include the following information:
|
||||
|
||||
- CPU model
|
||||
- Operating system
|
||||
- Compiler
|
||||
|
||||
```
|
94
examples/bench/bench.cpp
Normal file
94
examples/bench/bench.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
|
||||
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "If you wish, you can submit these results here:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "Please include the following information:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " - CPU model\n");
|
||||
fprintf(stderr, " - Operating system\n");
|
||||
fprintf(stderr, " - Compiler\n");
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
return 0;
|
||||
}
|
49
examples/command.wasm/CMakeLists.txt
Normal file
49
examples/command.wasm/CMakeLists.txt
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# libcommand
|
||||
#
|
||||
|
||||
set(TARGET libcommand)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside command.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libcommand.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/command.wasm/command.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
#
|
||||
# command.wasm
|
||||
#
|
||||
|
||||
set(TARGET command.wasm)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)
|
23
examples/command.wasm/README.md
Normal file
23
examples/command.wasm/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# command.wasm
|
||||
|
||||
This is a basic Voice Assistant example that accepts voice commands from the microphone.
|
||||
It runs in fully in the browser via WebAseembly.
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/command/
|
||||
|
||||
Terminal version: [examples/command](/examples/command)
|
||||
|
||||
## Build instructions
|
||||
|
||||
```bash
|
||||
# build using Emscripten (v3.1.2)
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/command.wasm/* /path/to/html/
|
||||
cp bin/libcommand.worker.js /path/to/html/
|
||||
```
|
408
examples/command.wasm/emscripten.cpp
Normal file
408
examples/command.wasm/emscripten.cpp
Normal file
@ -0,0 +1,408 @@
|
||||
#include "ggml.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::mutex g_mutex;
|
||||
std::thread g_worker;
|
||||
|
||||
std::atomic<bool> g_running(false);
|
||||
|
||||
std::string g_status = "";
|
||||
std::string g_status_forced = "";
|
||||
std::string g_transcribed = "";
|
||||
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
static std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
static float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
void command_set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
|
||||
const int64_t n_samples = (ms * sample_rate) / 1000;
|
||||
|
||||
int64_t n_take = 0;
|
||||
if (g_pcmf32.size() < n_samples) {
|
||||
n_take = g_pcmf32.size();
|
||||
} else {
|
||||
n_take = n_samples;
|
||||
}
|
||||
|
||||
audio.resize(n_take);
|
||||
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
|
||||
}
|
||||
|
||||
void command_main(size_t index) {
|
||||
command_set_status("loading data ...");
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
printf("command: using %d threads\n", wparams.n_threads);
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
bool print_energy = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
// whisper context
|
||||
auto & ctx = g_contexts[index];
|
||||
|
||||
const int32_t vad_ms = 2000;
|
||||
const int32_t prompt_ms = 5000;
|
||||
const int32_t command_ms = 4000;
|
||||
|
||||
const float vad_thold = 0.1f;
|
||||
const float freq_thold = -1.0f;
|
||||
|
||||
while (g_running) {
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
|
||||
command_set_status(txt);
|
||||
}
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||
|
||||
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
command_set_status("Speech detected! Processing ...");
|
||||
|
||||
if (!have_prompt) {
|
||||
command_get_audio(prompt_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
|
||||
command_set_status(txt);
|
||||
}
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
command_get_audio(command_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
|
||||
command_set_status(txt);
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_transcribed = command;
|
||||
}
|
||||
}
|
||||
|
||||
g_pcmf32.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(command) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
command_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (g_running) {
|
||||
g_running = false;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
|
||||
--index;
|
||||
|
||||
if (index >= g_contexts.size()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (g_contexts[index] == nullptr) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
const int n = audio["length"].as<int>();
|
||||
|
||||
emscripten::val heap = emscripten::val::module_property("HEAPU8");
|
||||
emscripten::val memory = heap["buffer"];
|
||||
|
||||
g_pcmf32.resize(n);
|
||||
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}));
|
||||
|
||||
emscripten::function("get_transcribed", emscripten::optional_override([]() {
|
||||
std::string transcribed;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
transcribed = std::move(g_transcribed);
|
||||
}
|
||||
|
||||
return transcribed;
|
||||
}));
|
||||
|
||||
emscripten::function("get_status", emscripten::optional_override([]() {
|
||||
std::string status;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
status = g_status_forced.empty() ? g_status : g_status_forced;
|
||||
}
|
||||
|
||||
return status;
|
||||
}));
|
||||
|
||||
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status_forced = status;
|
||||
}
|
||||
}));
|
||||
}
|
386
examples/command.wasm/index-tmpl.html
Normal file
386
examples/command.wasm/index-tmpl.html
Normal file
@ -0,0 +1,386 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>command : Voice assistant example using Whisper + WebAssembly</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<b>command : Voice assistant example using Whisper + WebAssembly</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and follow the instructions.
|
||||
|
||||
<br><br>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="input">
|
||||
<button id="start" onclick="onStart()" disabled>Start</button>
|
||||
<button id="stop" onclick="onStop()" disabled>Stop</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">not started</span></b>
|
||||
|
||||
<pre id="state-transcribed">[The recognized voice commands will be displayed here]</pre>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// web audio context
|
||||
var context = null;
|
||||
|
||||
// audio data
|
||||
var audio = null;
|
||||
var audio0 = null;
|
||||
|
||||
// the command instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Initialized successfully!');
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
if (model_whisper != null) {
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop' ).disabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
model_whisper = model;
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
//
|
||||
// microphone
|
||||
//
|
||||
|
||||
const kSampleRate = 16000;
|
||||
const kRestartRecording_s = 120;
|
||||
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
|
||||
|
||||
var mediaRecorder = null;
|
||||
var doRecording = false;
|
||||
var startTime = 0;
|
||||
|
||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
||||
|
||||
function stopRecording() {
|
||||
Module.set_status("paused");
|
||||
doRecording = false;
|
||||
audio0 = null;
|
||||
audio = null;
|
||||
context = null;
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
if (!context) {
|
||||
context = new AudioContext({
|
||||
sampleRate: kSampleRate,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
}
|
||||
|
||||
Module.set_status("");
|
||||
|
||||
document.getElementById('start').disabled = true;
|
||||
document.getElementById('stop').disabled = false;
|
||||
|
||||
doRecording = true;
|
||||
startTime = Date.now();
|
||||
|
||||
var chunks = [];
|
||||
var stream = null;
|
||||
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
mediaRecorder.ondataavailable = function(e) {
|
||||
chunks.push(e.data);
|
||||
|
||||
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
|
||||
var reader = new FileReader();
|
||||
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
context.decodeAudioData(buf.buffer, function(audioBuffer) {
|
||||
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
|
||||
var source = offlineContext.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
source.connect(offlineContext.destination);
|
||||
source.start(0);
|
||||
|
||||
offlineContext.startRendering().then(function(renderedBuffer) {
|
||||
audio = renderedBuffer.getChannelData(0);
|
||||
|
||||
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
|
||||
|
||||
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
|
||||
if (audio0 != null) {
|
||||
audioAll.set(audio0, 0);
|
||||
}
|
||||
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
|
||||
|
||||
if (instance) {
|
||||
Module.set_audio(instance, audioAll);
|
||||
}
|
||||
});
|
||||
}, function(e) {
|
||||
audio = null;
|
||||
});
|
||||
}
|
||||
|
||||
reader.readAsArrayBuffer(blob);
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = function(e) {
|
||||
if (doRecording) {
|
||||
setTimeout(function() {
|
||||
startRecording();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.start(kIntervalAudio_ms);
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
|
||||
var interval = setInterval(function() {
|
||||
if (!doRecording) {
|
||||
clearInterval(interval);
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop').disabled = true;
|
||||
|
||||
mediaRecorder = null;
|
||||
}
|
||||
|
||||
// if audio length is more than kRestartRecording_s seconds, restart recording
|
||||
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
|
||||
if (doRecording) {
|
||||
//printTextarea('js: restarting recording');
|
||||
|
||||
clearInterval(interval);
|
||||
audio0 = audio;
|
||||
audio = null;
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
var nLines = 0;
|
||||
var intervalUpdate = null;
|
||||
var transcribedAll = '';
|
||||
|
||||
function onStart() {
|
||||
if (!instance) {
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
}
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
|
||||
startRecording();
|
||||
|
||||
intervalUpdate = setInterval(function() {
|
||||
var transcribed = Module.get_transcribed();
|
||||
|
||||
if (transcribed != null && transcribed.length > 1) {
|
||||
transcribedAll += transcribed + '<br>';
|
||||
nLines++;
|
||||
|
||||
// if more than 10 lines, remove the first line
|
||||
if (nLines > 10) {
|
||||
var i = transcribedAll.indexOf('<br>');
|
||||
if (i > 0) {
|
||||
transcribedAll = transcribedAll.substring(i + 4);
|
||||
nLines--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById('state-status').innerHTML = Module.get_status();
|
||||
document.getElementById('state-transcribed').innerHTML = transcribedAll;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function onStop() {
|
||||
stopRecording();
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="command.js"></script>
|
||||
</body>
|
||||
</html>
|
10
examples/command/CMakeLists.txt
Normal file
10
examples/command/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
47
examples/command/README.md
Normal file
47
examples/command/README.md
Normal file
@ -0,0 +1,47 @@
|
||||
# command
|
||||
|
||||
This is a basic Voice Assistant example that accepts voice commands from the microphone.
|
||||
More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/issues/171).
|
||||
|
||||
```bash
|
||||
# Run with default arguments and small model
|
||||
./command -m ./models/ggml-small.en.bin -t 8
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Guided mode
|
||||
|
||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
||||
|
||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
||||
|
||||
```bash
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
||||
|
||||
|
||||
## Building
|
||||
|
||||
The `command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
make command
|
||||
```
|
1011
examples/command/command.cpp
Normal file
1011
examples/command/command.cpp
Normal file
File diff suppressed because it is too large
Load Diff
9
examples/command/commands.txt
Normal file
9
examples/command/commands.txt
Normal file
@ -0,0 +1,9 @@
|
||||
enable
|
||||
disable
|
||||
cat
|
||||
dog
|
||||
apple
|
||||
red
|
||||
blue
|
||||
green
|
||||
lightblue
|
60
examples/generate-karaoke.sh
Executable file
60
examples/generate-karaoke.sh
Executable file
@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Simple tool to record audio from the microphone and generate a karaoke video
|
||||
# Usage:
|
||||
#
|
||||
# cd whisper.cpp
|
||||
# make
|
||||
#
|
||||
# ./examples/generate-karaoke.sh [model] [step_ms]
|
||||
#
|
||||
# Press Ctrl+C to stop recording
|
||||
#
|
||||
|
||||
executable="./main"
|
||||
model="base.en"
|
||||
model_path="models/ggml-$model.bin"
|
||||
|
||||
# require sox and ffmpeg to be installed
|
||||
if ! command -v sox &> /dev/null
|
||||
then
|
||||
echo "sox could not be found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v ffmpeg &> /dev/null
|
||||
then
|
||||
echo "ffmpeg could not be found"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if [ ! -f "$executable" ]; then
|
||||
echo "'$executable' does not exist. Please build it first."
|
||||
exit 3
|
||||
fi
|
||||
|
||||
if [ ! -f "$model_path" ]; then
|
||||
echo "'$model_path' does not exist. Please download it first."
|
||||
exit 4
|
||||
fi
|
||||
|
||||
# record some raw audio
|
||||
sox -d rec.wav
|
||||
|
||||
# resample to 16kHz
|
||||
ffmpeg -y -i ./rec.wav -ar 16000 -ac 1 -c:a pcm_s16le ./rec16.wav > /dev/null 2>&1
|
||||
|
||||
# run Whisper
|
||||
echo "Processing ..."
|
||||
./main -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
|
||||
|
||||
# generate Karaoke video
|
||||
echo "Generating video ..."
|
||||
source rec16.wav.wts > /dev/null 2>&1
|
||||
|
||||
# play the video
|
||||
echo "Playing ./rec16.wav.mp4 ..."
|
||||
ffplay -loglevel 0 -autoexit ./rec16.wav.mp4
|
||||
|
||||
echo "Done"
|
||||
exit 0
|
182
examples/helpers.js
Normal file
182
examples/helpers.js
Normal file
@ -0,0 +1,182 @@
|
||||
// Common Javascript functions used by the examples
|
||||
|
||||
function convertTypedArray(src, type) {
|
||||
var buffer = new ArrayBuffer(src.byteLength);
|
||||
var baseView = new src.constructor(buffer).set(src);
|
||||
return new type(buffer);
|
||||
}
|
||||
|
||||
var printTextarea = (function() {
|
||||
var element = document.getElementById('output');
|
||||
if (element) element.alue = ''; // clear browser cache
|
||||
return function(text) {
|
||||
if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
|
||||
console.log(text);
|
||||
if (element) {
|
||||
element.value += text + "\n";
|
||||
element.scrollTop = element.scrollHeight; // focus on bottom
|
||||
}
|
||||
};
|
||||
})();
|
||||
|
||||
async function clearCache() {
|
||||
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
|
||||
indexedDB.deleteDatabase(dbName);
|
||||
}
|
||||
}
|
||||
|
||||
// fetch a remote file from remote URL using the Fetch API
|
||||
async function fetchRemote(url, cbProgress, cbPrint) {
|
||||
cbPrint('fetchRemote: downloading with fetch()...');
|
||||
|
||||
const response = await fetch(
|
||||
url,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/octet-stream',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
cbPrint('fetchRemote: failed to fetch ' + url);
|
||||
return;
|
||||
}
|
||||
|
||||
const contentLength = response.headers.get('content-length');
|
||||
const total = parseInt(contentLength, 10);
|
||||
const reader = response.body.getReader();
|
||||
|
||||
var chunks = [];
|
||||
var receivedLength = 0;
|
||||
var progressLast = -1;
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
chunks.push(value);
|
||||
receivedLength += value.length;
|
||||
|
||||
if (contentLength) {
|
||||
cbProgress(receivedLength/total);
|
||||
|
||||
var progressCur = Math.round((receivedLength / total) * 10);
|
||||
if (progressCur != progressLast) {
|
||||
cbPrint('fetchRemote: fetching ' + 10*progressCur + '% ...');
|
||||
progressLast = progressCur;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var position = 0;
|
||||
var chunksAll = new Uint8Array(receivedLength);
|
||||
|
||||
for (var chunk of chunks) {
|
||||
chunksAll.set(chunk, position);
|
||||
position += chunk.length;
|
||||
}
|
||||
|
||||
return chunksAll;
|
||||
}
|
||||
|
||||
// load remote data
|
||||
// - check if the data is already in the IndexedDB
|
||||
// - if not, fetch it from the remote URL and store it in the IndexedDB
|
||||
function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
// query the storage quota and print it
|
||||
navigator.storage.estimate().then(function (estimate) {
|
||||
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
|
||||
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
|
||||
});
|
||||
|
||||
// check if the data is already in the IndexedDB
|
||||
var rq = indexedDB.open(dbName, dbVersion);
|
||||
|
||||
rq.onupgradeneeded = function (event) {
|
||||
var db = event.target.result;
|
||||
if (db.version == 1) {
|
||||
var os = db.createObjectStore('models', { autoIncrement: false });
|
||||
cbPrint('loadRemote: created IndexedDB ' + db.name + ' version ' + db.version);
|
||||
} else {
|
||||
// clear the database
|
||||
var os = event.currentTarget.transaction.objectStore('models');
|
||||
os.clear();
|
||||
cbPrint('loadRemote: cleared IndexedDB ' + db.name + ' version ' + db.version);
|
||||
}
|
||||
};
|
||||
|
||||
rq.onsuccess = function (event) {
|
||||
var db = event.target.result;
|
||||
var tx = db.transaction(['models'], 'readonly');
|
||||
var os = tx.objectStore('models');
|
||||
var rq = os.get(url);
|
||||
|
||||
rq.onsuccess = function (event) {
|
||||
if (rq.result) {
|
||||
cbPrint('loadRemote: "' + url + '" is already in the IndexedDB');
|
||||
cbReady(dst, rq.result);
|
||||
} else {
|
||||
// data is not in the IndexedDB
|
||||
cbPrint('loadRemote: "' + url + '" is not in the IndexedDB');
|
||||
|
||||
// alert and ask the user to confirm
|
||||
if (!confirm(
|
||||
'You are about to download ' + size_mb + ' MB of data.\n' +
|
||||
'The model data will be cached in the browser for future use.\n\n' +
|
||||
'Press OK to continue.')) {
|
||||
cbCancel();
|
||||
return;
|
||||
}
|
||||
|
||||
fetchRemote(url, cbProgress, cbPrint).then(function (data) {
|
||||
if (data) {
|
||||
// store the data in the IndexedDB
|
||||
var rq = indexedDB.open(dbName, dbVersion);
|
||||
rq.onsuccess = function (event) {
|
||||
var db = event.target.result;
|
||||
var tx = db.transaction(['models'], 'readwrite');
|
||||
var os = tx.objectStore('models');
|
||||
var rq = os.put(data, url);
|
||||
|
||||
rq.onsuccess = function (event) {
|
||||
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
|
||||
cbReady(dst, data);
|
||||
};
|
||||
|
||||
rq.onerror = function (event) {
|
||||
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB');
|
||||
cbCancel();
|
||||
};
|
||||
};
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
rq.onerror = function (event) {
|
||||
cbPrint('loadRemote: failed to get data from the IndexedDB');
|
||||
cbCancel();
|
||||
};
|
||||
};
|
||||
|
||||
rq.onerror = function (event) {
|
||||
cbPrint('loadRemote: failed to open IndexedDB');
|
||||
cbCancel();
|
||||
};
|
||||
|
||||
rq.onblocked = function (event) {
|
||||
cbPrint('loadRemote: failed to open IndexedDB: blocked');
|
||||
cbCancel();
|
||||
};
|
||||
|
||||
rq.onabort = function (event) {
|
||||
cbPrint('loadRemote: failed to open IndexedDB: abort');
|
||||
|
||||
};
|
||||
}
|
||||
|
112
examples/livestream.sh
Executable file
112
examples/livestream.sh
Executable file
@ -0,0 +1,112 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
|
||||
# Idea by @semiformal-net
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/185
|
||||
#
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
url="http://a.files.bbci.co.uk/media/live/manifesto/audio/simulcast/hls/nonuk/sbr_low/ak/bbc_world_service.m3u8"
|
||||
fmt=aac # the audio format extension of the stream (TODO: auto detect)
|
||||
step_s=30
|
||||
model="base.en"
|
||||
|
||||
check_requirements()
|
||||
{
|
||||
if ! command -v ./main &>/dev/null; then
|
||||
echo "whisper.cpp main executable is required (make)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v ffmpeg &>/dev/null; then
|
||||
echo "ffmpeg is required (https://ffmpeg.org)"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_requirements
|
||||
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 stream_url [step_s] [model]"
|
||||
echo ""
|
||||
echo " Example:"
|
||||
echo " $0 $url $step_s $model"
|
||||
echo ""
|
||||
echo "No url specified, using default: $url"
|
||||
else
|
||||
url="$1"
|
||||
fi
|
||||
|
||||
if [ -n "$2" ]; then
|
||||
step_s="$2"
|
||||
fi
|
||||
|
||||
if [ -n "$3" ]; then
|
||||
model="$3"
|
||||
fi
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
printf "\n"
|
||||
printf " Available models:"
|
||||
for model in "${models[@]}"; do
|
||||
printf " $model"
|
||||
done
|
||||
printf "\n\n"
|
||||
}
|
||||
|
||||
if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
|
||||
printf "Invalid model: $model\n"
|
||||
list_models
|
||||
|
||||
exit 1
|
||||
fi
|
||||
|
||||
running=1
|
||||
|
||||
trap "running=0" SIGINT SIGTERM
|
||||
|
||||
printf "[+] Transcribing stream with model '$model', step_s $step_s (press Ctrl+C to stop):\n\n"
|
||||
|
||||
# continuous stream in native fmt (this file will grow forever!)
|
||||
ffmpeg -loglevel quiet -y -re -probesize 32 -i $url -c copy /tmp/whisper-live0.${fmt} &
|
||||
if [ $? -ne 0 ]; then
|
||||
printf "Error: ffmpeg failed to capture audio stream\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Buffering audio. Please wait...\n\n"
|
||||
sleep $(($step_s))
|
||||
|
||||
# do not stop script on error
|
||||
set +e
|
||||
|
||||
i=0
|
||||
SECONDS=0
|
||||
while [ $running -eq 1 ]; do
|
||||
# extract the next piece from the main file above and transcode to wav. -ss sets start time and nudges it by -0.5s to catch missing words (??)
|
||||
err=1
|
||||
while [ $err -ne 0 ]; do
|
||||
if [ $i -gt 0 ]; then
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.${fmt} -y -ar 16000 -ac 1 -c:a pcm_s16le -ss $(($i*$step_s-1)).5 -t $step_s /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
else
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.${fmt} -y -ar 16000 -ac 1 -c:a pcm_s16le -ss $(($i*$step_s)) -t $step_s /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
fi
|
||||
err=$(cat /tmp/whisper-live.err | wc -l)
|
||||
done
|
||||
|
||||
./main -t 8 -m ./models/ggml-base.en.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
|
||||
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
|
||||
sleep 1
|
||||
done
|
||||
((i=i+1))
|
||||
done
|
||||
|
||||
killall -v ffmpeg
|
||||
killall -v main
|
6
examples/main/CMakeLists.txt
Normal file
6
examples/main/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
33
examples/main/README.md
Normal file
33
examples/main/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# main
|
||||
|
||||
This is the main example demonstrating most of the functionality of the Whisper model.
|
||||
It can be used as a reference for using the `whisper.cpp` library in other projects.
|
||||
|
||||
```
|
||||
./main -h
|
||||
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
```
|
730
examples/main/main.cpp
Normal file
730
examples/main/main.cpp
Normal file
@ -0,0 +1,730 @@
|
||||
#include "whisper.h"
|
||||
|
||||
// third-party utilities
|
||||
// use your favorite implementations
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
||||
// Lowest is red, middle is yellow, highest is green.
|
||||
const std::vector<std::string> k_colors = {
|
||||
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
|
||||
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
|
||||
};
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
int64_t msec = t * 10;
|
||||
int64_t hr = msec / (1000 * 60 * 60);
|
||||
msec = msec - hr * (1000 * 60 * 60);
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
int timestamp_to_sample(int64_t t, int n_samples) {
|
||||
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
||||
}
|
||||
|
||||
// helper function to replace substrings
|
||||
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
||||
for (size_t pos = 0; ; pos += replace.length()) {
|
||||
pos = s.find(search, pos);
|
||||
if (pos == std::string::npos) break;
|
||||
s.erase(pos, search.length());
|
||||
s.insert(pos, replace);
|
||||
}
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 5;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg[0] != '-') {
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
struct whisper_print_user_data {
|
||||
const whisper_params * params;
|
||||
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
fout << text << "\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_vtt(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
fout << "WEBVTT\n\n";
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
|
||||
fout << text << "\n\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
fout << i + 1 + params.offset_n << "\n";
|
||||
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
|
||||
fout << text << "\n\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
if (text[0] == ' ') {
|
||||
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
|
||||
}
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
|
||||
std::ofstream fout(fname);
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
// TODO: become parameter
|
||||
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
|
||||
fout << "#!/bin/bash" << "\n";
|
||||
fout << "\n";
|
||||
|
||||
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
|
||||
|
||||
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
|
||||
std::vector<whisper_token_data> tokens(n);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
tokens[j] = whisper_full_get_token_data(ctx, i, j);
|
||||
}
|
||||
|
||||
if (i > 0) {
|
||||
fout << ",";
|
||||
}
|
||||
|
||||
// background text
|
||||
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
|
||||
|
||||
bool is_first = true;
|
||||
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const auto & token = tokens[j];
|
||||
|
||||
if (tokens[j].id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string txt_bg;
|
||||
std::string txt_fg; // highlight token
|
||||
std::string txt_ul; // underline
|
||||
|
||||
txt_bg = "> ";
|
||||
txt_fg = "> ";
|
||||
txt_ul = "\\ \\ ";
|
||||
|
||||
{
|
||||
for (int k = 0; k < n; ++k) {
|
||||
const auto & token2 = tokens[k];
|
||||
|
||||
if (tokens[k].id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string txt = whisper_token_to_str(ctx, token2.id);
|
||||
|
||||
txt_bg += txt;
|
||||
|
||||
if (k == j) {
|
||||
for (int l = 0; l < (int) txt.size(); ++l) {
|
||||
txt_fg += txt[l];
|
||||
txt_ul += "_";
|
||||
}
|
||||
txt_fg += "|";
|
||||
} else {
|
||||
for (int l = 0; l < (int) txt.size(); ++l) {
|
||||
txt_fg += "\\ ";
|
||||
txt_ul += "\\ ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
::replace_all(txt_bg, "'", "\u2019");
|
||||
::replace_all(txt_bg, "\"", "\\\"");
|
||||
::replace_all(txt_fg, "'", "\u2019");
|
||||
::replace_all(txt_fg, "\"", "\\\"");
|
||||
}
|
||||
|
||||
if (is_first) {
|
||||
// background text
|
||||
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
|
||||
is_first = false;
|
||||
}
|
||||
|
||||
// foreground text
|
||||
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
|
||||
|
||||
// underline
|
||||
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
|
||||
}
|
||||
}
|
||||
|
||||
fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
|
||||
|
||||
fout << "\n\n";
|
||||
fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
|
||||
fout << "\n";
|
||||
fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
|
||||
fout << "\n";
|
||||
|
||||
fout.close();
|
||||
|
||||
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.fname_inp.empty()) {
|
||||
fprintf(stderr, "error: no input files specified\n");
|
||||
whisper_print_usage(argc, argv, params);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (!params.prompt.empty()) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||
fprintf(stderr, "initial tokens: [ ");
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||
}
|
||||
fprintf(stderr, "]\n");
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true)
|
||||
{
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return 5;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return 8;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
||||
return 9;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
wparams.temperature_inc = -1;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (!wparams.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
// example for abort mechanism
|
||||
// in this example, we do not abort the processing, but we could if the flag is set to true
|
||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
return 10;
|
||||
}
|
||||
}
|
||||
|
||||
// output stuff
|
||||
{
|
||||
printf("\n");
|
||||
|
||||
// output to text file
|
||||
if (params.output_txt) {
|
||||
const auto fname_txt = fname_inp + ".txt";
|
||||
output_txt(ctx, fname_txt.c_str());
|
||||
}
|
||||
|
||||
// output to VTT file
|
||||
if (params.output_vtt) {
|
||||
const auto fname_vtt = fname_inp + ".vtt";
|
||||
output_vtt(ctx, fname_vtt.c_str());
|
||||
}
|
||||
|
||||
// output to SRT file
|
||||
if (params.output_srt) {
|
||||
const auto fname_srt = fname_inp + ".srt";
|
||||
output_srt(ctx, fname_srt.c_str(), params);
|
||||
}
|
||||
|
||||
// output to WTS file
|
||||
if (params.output_wts) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_inp + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
49
examples/stream.wasm/CMakeLists.txt
Normal file
49
examples/stream.wasm/CMakeLists.txt
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# libstream
|
||||
#
|
||||
|
||||
set(TARGET libstream)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside stream.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libstream.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/stream.wasm/stream.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
#
|
||||
# stream.wasm
|
||||
#
|
||||
|
||||
set(TARGET stream.wasm)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)
|
20
examples/stream.wasm/README.md
Normal file
20
examples/stream.wasm/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# stream.wasm
|
||||
|
||||
Real-time transcription in the browser using WebAssembly
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/stream/
|
||||
|
||||
## Build instructions
|
||||
|
||||
```bash
|
||||
# build using Emscripten (v3.1.2)
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/stream.wasm/* /path/to/html/
|
||||
cp bin/libstream.worker.js /path/to/html/
|
||||
```
|
216
examples/stream.wasm/emscripten.cpp
Normal file
216
examples/stream.wasm/emscripten.cpp
Normal file
@ -0,0 +1,216 @@
|
||||
#include "ggml.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::mutex g_mutex;
|
||||
std::thread g_worker;
|
||||
|
||||
std::atomic<bool> g_running(false);
|
||||
|
||||
std::string g_status = "";
|
||||
std::string g_status_forced = "";
|
||||
std::string g_transcribed = "";
|
||||
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
void stream_set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
void stream_main(size_t index) {
|
||||
stream_set_status("loading data ...");
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
printf("stream: using %d threads\n", wparams.n_threads);
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
|
||||
// whisper context
|
||||
auto & ctx = g_contexts[index];
|
||||
|
||||
// 5 seconds interval
|
||||
const int64_t window_samples = 5*WHISPER_SAMPLE_RATE;
|
||||
|
||||
while (g_running) {
|
||||
stream_set_status("waiting for audio ...");
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
|
||||
if (g_pcmf32.size() < 1024) {
|
||||
lock.unlock();
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
|
||||
g_pcmf32.clear();
|
||||
}
|
||||
|
||||
{
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
stream_set_status("running whisper ...");
|
||||
|
||||
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
|
||||
if (ret != 0) {
|
||||
printf("whisper_full() failed: %d\n", ret);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
printf("stream: whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
|
||||
}
|
||||
|
||||
{
|
||||
std::string text_heard;
|
||||
|
||||
{
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = n_segments - 1; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
printf("transcribed: %s\n", text);
|
||||
|
||||
text_heard += text;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_transcribed = text_heard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(stream) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
stream_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (g_running) {
|
||||
g_running = false;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
|
||||
--index;
|
||||
|
||||
if (index >= g_contexts.size()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (g_contexts[index] == nullptr) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
const int n = audio["length"].as<int>();
|
||||
|
||||
emscripten::val heap = emscripten::val::module_property("HEAPU8");
|
||||
emscripten::val memory = heap["buffer"];
|
||||
|
||||
g_pcmf32.resize(n);
|
||||
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}));
|
||||
|
||||
emscripten::function("get_transcribed", emscripten::optional_override([]() {
|
||||
std::string transcribed;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
transcribed = std::move(g_transcribed);
|
||||
}
|
||||
|
||||
return transcribed;
|
||||
}));
|
||||
|
||||
emscripten::function("get_status", emscripten::optional_override([]() {
|
||||
std::string status;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
status = g_status_forced.empty() ? g_status : g_status_forced;
|
||||
}
|
||||
|
||||
return status;
|
||||
}));
|
||||
|
||||
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status_forced = status;
|
||||
}
|
||||
}));
|
||||
}
|
386
examples/stream.wasm/index-tmpl.html
Normal file
386
examples/stream.wasm/index-tmpl.html
Normal file
@ -0,0 +1,386 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>stream : Real-time Whisper transcription in WebAssembly</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<b>stream : Real-time Whisper transcription in WebAssembly</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/stream.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and start speaking
|
||||
|
||||
<br><br>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="input">
|
||||
<button id="start" onclick="onStart()" disabled>Start</button>
|
||||
<button id="stop" onclick="onStop()" disabled>Stop</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">not started</span></b>
|
||||
|
||||
<pre id="state-transcribed">[The transcribed text will be displayed here]</pre>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/stream.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// web audio context
|
||||
var context = null;
|
||||
|
||||
// audio data
|
||||
var audio = null;
|
||||
var audio0 = null;
|
||||
|
||||
// the stream instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Initialized successfully!');
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
if (model_whisper != null) {
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop' ).disabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
model_whisper = model;
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
//
|
||||
// microphone
|
||||
//
|
||||
|
||||
const kSampleRate = 16000;
|
||||
const kRestartRecording_s = 120;
|
||||
const kIntervalAudio_ms = 5000; // pass the recorded audio to the C++ instance at this rate
|
||||
|
||||
var mediaRecorder = null;
|
||||
var doRecording = false;
|
||||
var startTime = 0;
|
||||
|
||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
||||
|
||||
function stopRecording() {
|
||||
Module.set_status("paused");
|
||||
doRecording = false;
|
||||
audio0 = null;
|
||||
audio = null;
|
||||
context = null;
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
if (!context) {
|
||||
context = new AudioContext({
|
||||
sampleRate: kSampleRate,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
}
|
||||
|
||||
Module.set_status("");
|
||||
|
||||
document.getElementById('start').disabled = true;
|
||||
document.getElementById('stop').disabled = false;
|
||||
|
||||
doRecording = true;
|
||||
startTime = Date.now();
|
||||
|
||||
var chunks = [];
|
||||
var stream = null;
|
||||
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
mediaRecorder.ondataavailable = function(e) {
|
||||
chunks.push(e.data);
|
||||
|
||||
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
|
||||
var reader = new FileReader();
|
||||
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
context.decodeAudioData(buf.buffer, function(audioBuffer) {
|
||||
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
|
||||
var source = offlineContext.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
source.connect(offlineContext.destination);
|
||||
source.start(0);
|
||||
|
||||
offlineContext.startRendering().then(function(renderedBuffer) {
|
||||
audio = renderedBuffer.getChannelData(0);
|
||||
|
||||
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
|
||||
|
||||
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
|
||||
if (audio0 != null) {
|
||||
audioAll.set(audio0, 0);
|
||||
}
|
||||
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
|
||||
|
||||
if (instance) {
|
||||
Module.set_audio(instance, audioAll);
|
||||
}
|
||||
});
|
||||
}, function(e) {
|
||||
audio = null;
|
||||
});
|
||||
}
|
||||
|
||||
reader.readAsArrayBuffer(blob);
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = function(e) {
|
||||
if (doRecording) {
|
||||
setTimeout(function() {
|
||||
startRecording();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.start(kIntervalAudio_ms);
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
|
||||
var interval = setInterval(function() {
|
||||
if (!doRecording) {
|
||||
clearInterval(interval);
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop').disabled = true;
|
||||
|
||||
mediaRecorder = null;
|
||||
}
|
||||
|
||||
// if audio length is more than kRestartRecording_s seconds, restart recording
|
||||
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
|
||||
if (doRecording) {
|
||||
//printTextarea('js: restarting recording');
|
||||
|
||||
clearInterval(interval);
|
||||
audio0 = audio;
|
||||
audio = null;
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
var nLines = 0;
|
||||
var intervalUpdate = null;
|
||||
var transcribedAll = '';
|
||||
|
||||
function onStart() {
|
||||
if (!instance) {
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
}
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
|
||||
startRecording();
|
||||
|
||||
intervalUpdate = setInterval(function() {
|
||||
var transcribed = Module.get_transcribed();
|
||||
|
||||
if (transcribed != null && transcribed.length > 1) {
|
||||
transcribedAll += transcribed + '<br>';
|
||||
nLines++;
|
||||
|
||||
// if more than 10 lines, remove the first line
|
||||
if (nLines > 10) {
|
||||
var i = transcribedAll.indexOf('<br>');
|
||||
if (i > 0) {
|
||||
transcribedAll = transcribedAll.substring(i + 4);
|
||||
nLines--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById('state-status').innerHTML = Module.get_status();
|
||||
document.getElementById('state-transcribed').innerHTML = transcribedAll;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function onStop() {
|
||||
stopRecording();
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="stream.js"></script>
|
||||
</body>
|
||||
</html>
|
10
examples/stream/CMakeLists.txt
Normal file
10
examples/stream/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
44
examples/stream/README.md
Normal file
44
examples/stream/README.md
Normal file
@ -0,0 +1,44 @@
|
||||
# stream
|
||||
|
||||
This is a naive example of performing real-time inference on audio from your microphone.
|
||||
The `stream` tool samples the audio every half a second and runs the transcription continously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
|
||||
|
||||
## Sliding window mode with VAD
|
||||
|
||||
Setting the `--step` argument to `0` enables the sliding window mode:
|
||||
|
||||
```java
|
||||
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
|
||||
```
|
||||
|
||||
In this mode, the tool will transcribe only after some speech activity is detected. A very
|
||||
basic VAD detector is used, but in theory a more sophisticated approach can be added. The
|
||||
`-vth` argument determines the VAD threshold - higher values will make it detect silence more often.
|
||||
It's best to tune it to the specific use case, but a value around `0.6` should be OK in general.
|
||||
When silence is detected, it will transcribe the last `--length` milliseconds of audio and output
|
||||
a transcription block that is suitable for parsing.
|
||||
|
||||
## Building
|
||||
|
||||
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
make stream
|
||||
```
|
||||
|
||||
## Web version
|
||||
|
||||
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)
|
710
examples/stream/stream.cpp
Normal file
710
examples/stream/stream.cpp
Normal file
@ -0,0 +1,710 @@
|
||||
// Real-time speech recognition of input from a microphone
|
||||
//
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t) {
|
||||
int64_t sec = t/100;
|
||||
int64_t msec = t - sec*100;
|
||||
int64_t min = sec/60;
|
||||
sec = sec - min*60;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t step_ms = 3000;
|
||||
int32_t length_ms = 10000;
|
||||
int32_t keep_ms = 200;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool no_context = true;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
|
||||
|
||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context = use_vad;
|
||||
params.max_tokens = 0;
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(params.length_ms);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
// whisper init
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old;
|
||||
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
n_samples_step,
|
||||
float(n_samples_step)/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_len )/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_keep)/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
if (!use_vad) {
|
||||
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
|
||||
} else {
|
||||
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
|
||||
std::ofstream fout;
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout.open(params.fname_out);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
printf("[Start speaking]");
|
||||
fflush(stdout);
|
||||
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
const auto t_start = t_last;
|
||||
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process new audio
|
||||
|
||||
if (!use_vad) {
|
||||
while (true) {
|
||||
audio.get(params.step_ms, pcmf32_new);
|
||||
|
||||
if ((int) pcmf32_new.size() > 2*n_samples_step) {
|
||||
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
|
||||
audio.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((int) pcmf32_new.size() >= n_samples_step) {
|
||||
audio.clear();
|
||||
break;
|
||||
}
|
||||
|
||||
SDL_Delay(1);
|
||||
}
|
||||
|
||||
const int n_samples_new = pcmf32_new.size();
|
||||
|
||||
// take up to params.length_ms audio from previous iteration
|
||||
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
|
||||
|
||||
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
|
||||
|
||||
pcmf32.resize(n_samples_new + n_samples_take);
|
||||
|
||||
for (int i = 0; i < n_samples_take; i++) {
|
||||
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
|
||||
}
|
||||
|
||||
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
|
||||
|
||||
pcmf32_old = pcmf32;
|
||||
} else {
|
||||
const auto t_now = std::chrono::high_resolution_clock::now();
|
||||
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
|
||||
|
||||
if (t_diff < 2000) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
audio.get(2000, pcmf32_new);
|
||||
|
||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
audio.get(params.length_ms, pcmf32);
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
t_last = t_now;
|
||||
}
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = !use_vad;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
|
||||
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
return 6;
|
||||
}
|
||||
|
||||
// print result;
|
||||
{
|
||||
if (!use_vad) {
|
||||
printf("\33[2K\r");
|
||||
|
||||
// print long empty line to clear the previous line
|
||||
printf("%s", std::string(100, ' ').c_str());
|
||||
|
||||
printf("\33[2K\r");
|
||||
} else {
|
||||
const int64_t t1 = (t_last - t_start).count()/1000000;
|
||||
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
|
||||
|
||||
printf("\n");
|
||||
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
if (params.no_timestamps) {
|
||||
printf("%s", text);
|
||||
fflush(stdout);
|
||||
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout << text;
|
||||
}
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "] " << text << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout << std::endl;
|
||||
}
|
||||
|
||||
if (use_vad){
|
||||
printf("\n");
|
||||
printf("### Transcription %d END\n", n_iter);
|
||||
}
|
||||
}
|
||||
|
||||
++n_iter;
|
||||
|
||||
if (!use_vad && (n_iter % n_new_line) == 0) {
|
||||
printf("\n");
|
||||
|
||||
// keep part of the audio for next iteration to try to mitigate word boundary issues
|
||||
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
|
||||
|
||||
// Add tokens of the last full length segment as the prompt
|
||||
if (!params.no_context) {
|
||||
prompt_tokens.clear();
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const int token_count = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < token_count; ++j) {
|
||||
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
50
examples/talk.wasm/CMakeLists.txt
Normal file
50
examples/talk.wasm/CMakeLists.txt
Normal file
@ -0,0 +1,50 @@
|
||||
#
|
||||
# libtalk
|
||||
#
|
||||
|
||||
set(TARGET libtalk)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
gpt-2.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside talk.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libtalk.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/talk.wasm/talk.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1800MB \
|
||||
-s TOTAL_MEMORY=1800MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
#
|
||||
# talk.wasm
|
||||
#
|
||||
|
||||
set(TARGET talk.wasm)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)
|
74
examples/talk.wasm/README.md
Normal file
74
examples/talk.wasm/README.md
Normal file
@ -0,0 +1,74 @@
|
||||
# talk.wasm
|
||||
|
||||
Talk with an Artificial Intelligence in your browser:
|
||||
|
||||
[https://user-images.githubusercontent.com/1991296/203411580-fedb4839-05e4-4474-8364-aaf1e9a9b615.mp4](https://user-images.githubusercontent.com/1991296/203845553-f7b44e13-9a15-4fc8-b518-ae8f4c6770fe.mp4)
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/talk/
|
||||
|
||||
Terminal version: [examples/talk](/examples/talk)
|
||||
|
||||
## How it works?
|
||||
|
||||
This demo leverages 2 modern neural network models to create a high-quality voice chat directly in your browser:
|
||||
|
||||
- [OpenAI's Whisper](https://github.com/openai/whisper) speech recognition model is used to process your voice and understand what you are saying
|
||||
- Upon receiving some voice input, the AI generates a text response using [OpenAI's GPT-2](https://github.com/openai/gpt-2) language model
|
||||
- The AI then vocalizes the response using the browser's [Web Speech API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API)
|
||||
|
||||
The web page does the processing locally on your machine. The processing of these heavy neural network models in the
|
||||
browser is possible by implementing them efficiently in C/C++ and using the browser's WebAssembly SIMD capabilities for
|
||||
extra performance:
|
||||
|
||||
- The Whisper C++ implementation is here: [whisper.h](/whisper.h) / [whisper.cpp](/whisper.cpp)
|
||||
- The GPT-2 C++ implementation is here: [gpt-2.h](gpt-2.h) / [gpt-2.cpp](gpt-2.cpp)
|
||||
- Both models use a custom tensor library implemented in C: [ggml.h](/ggml.h) / [ggml.c](/ggml.c)
|
||||
- The HTML/JS layer is here: [index-tmpl.html](index-tmpl.html)
|
||||
- The Emscripten bridge between C/C++ and JS is here: [emscripten.cpp](emscripten.cpp)
|
||||
|
||||
In order to run the models, the web page first needs to download the model data which is about ~350 MB. The model data
|
||||
is then cached in your browser's cache and can be reused in future visits without downloading it again.
|
||||
|
||||
## Requirements
|
||||
|
||||
In order to run this demo efficiently, you need to have the following:
|
||||
|
||||
- Latest Chrome or Firefox browser (Safari is not supported)
|
||||
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
|
||||
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
|
||||
- The web-page uses about 1.8GB of RAM
|
||||
|
||||
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
|
||||
Also, the prompting strategy can likely be improved to achieve better results.
|
||||
|
||||
The demo is quite computationally heavy, so you need a fast CPU. It's not usual to run these transformer models in a
|
||||
browser. Typically, they run on powerful GPUs.
|
||||
|
||||
Currently, mobile browsers do not support the Fixed-width SIMD WebAssembly capability, so you cannot run this demo
|
||||
on a phone or a tablet. Hopefully, in the near future this will become supported.
|
||||
|
||||
## Todo
|
||||
|
||||
- Better UI (contributions are welcome)
|
||||
- Better GPT-2 prompting
|
||||
|
||||
## Build instructions
|
||||
|
||||
```bash
|
||||
# build using Emscripten (v3.1.2)
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/talk.wasm/* /path/to/html/
|
||||
cp bin/libtalk.worker.js /path/to/html/
|
||||
```
|
||||
|
||||
## Feedback
|
||||
|
||||
If you have any comments or ideas for improvement, please drop a comment in the following discussion:
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/discussions/167
|
380
examples/talk.wasm/emscripten.cpp
Normal file
380
examples/talk.wasm/emscripten.cpp
Normal file
@ -0,0 +1,380 @@
|
||||
#include "ggml.h"
|
||||
#include "gpt-2.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
struct gpt2_context * g_gpt2;
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::mutex g_mutex;
|
||||
std::thread g_worker;
|
||||
std::atomic<bool> g_running(false);
|
||||
|
||||
bool g_force_speak = false;
|
||||
std::string g_text_to_speak = "";
|
||||
std::string g_status = "";
|
||||
std::string g_status_forced = "";
|
||||
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
std::string to_timestamp(int64_t t) {
|
||||
int64_t sec = t/100;
|
||||
int64_t msec = t - sec*100;
|
||||
int64_t min = sec/60;
|
||||
sec = sec - min*60;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
void talk_set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
void talk_main(size_t index) {
|
||||
talk_set_status("loading data ...");
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
g_gpt2 = gpt2_init("gpt-2.bin");
|
||||
|
||||
printf("talk: using %d threads\n", wparams.n_threads);
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
|
||||
// whisper context
|
||||
auto & ctx = g_contexts[index];
|
||||
|
||||
const int64_t step_samples = 2*WHISPER_SAMPLE_RATE;
|
||||
const int64_t window_samples = 9*WHISPER_SAMPLE_RATE;
|
||||
const int64_t step_ms = (step_samples*1000)/WHISPER_SAMPLE_RATE;
|
||||
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
|
||||
talk_set_status("listening ...");
|
||||
|
||||
while (g_running) {
|
||||
|
||||
const auto t_now = std::chrono::high_resolution_clock::now();
|
||||
if (std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count() < step_ms) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_pcmf32.clear();
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
continue;
|
||||
}
|
||||
|
||||
talk_set_status("listening ...");
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
|
||||
if (g_pcmf32.size() < step_samples) {
|
||||
lock.unlock();
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
|
||||
}
|
||||
|
||||
// VAD: if energy in during last second is above threshold, then skip
|
||||
{
|
||||
float energy_all = 0.0f;
|
||||
float energy_1s = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < pcmf32.size(); i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= pcmf32.size() - WHISPER_SAMPLE_RATE) {
|
||||
energy_1s += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= pcmf32.size();
|
||||
energy_1s /= WHISPER_SAMPLE_RATE;
|
||||
|
||||
if (energy_1s > 0.1f*energy_all && !g_force_speak) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
talk_set_status("processing audio (whisper)...");
|
||||
|
||||
t_last = t_now;
|
||||
|
||||
if (!g_force_speak) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
|
||||
if (ret != 0) {
|
||||
printf("whisper_full() failed: %d\n", ret);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
printf("whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
|
||||
}
|
||||
|
||||
{
|
||||
std::string text_heard;
|
||||
|
||||
if (!g_force_speak) {
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = n_segments - 1; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
|
||||
text_heard += text;
|
||||
}
|
||||
}
|
||||
|
||||
g_force_speak = false;
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\(.*?\\)");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
|
||||
|
||||
talk_set_status("'" + text_heard + "' - thinking how to respond (gpt-2) ...");
|
||||
|
||||
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(g_gpt2, text_heard.c_str());
|
||||
|
||||
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
|
||||
|
||||
std::string text_to_speak;
|
||||
std::string prompt_base;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
prompt_base = gpt2_get_prompt(g_gpt2);
|
||||
}
|
||||
|
||||
if (tokens.size() > 0) {
|
||||
text_to_speak = gpt2_gen_text(g_gpt2, (prompt_base + text_heard + "\n").c_str(), 32);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
|
||||
// remove first 2 lines of base prompt
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of("\n");
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of("\n");
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
prompt_base += text_heard + "\n" + text_to_speak + "\n";
|
||||
} else {
|
||||
text_to_speak = gpt2_gen_text(g_gpt2, prompt_base.c_str(), 32);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
|
||||
const size_t pos = prompt_base.find_first_of("\n");
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
prompt_base += text_to_speak + "\n";
|
||||
}
|
||||
|
||||
printf("gpt-2: %s\n", text_to_speak.c_str());
|
||||
|
||||
//printf("========================\n");
|
||||
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
|
||||
//printf("========================\n");
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
t_last = std::chrono::high_resolution_clock::now();
|
||||
g_text_to_speak = text_to_speak;
|
||||
g_pcmf32.clear();
|
||||
gpt2_set_prompt(g_gpt2, prompt_base.c_str());
|
||||
}
|
||||
|
||||
talk_set_status("speaking ...");
|
||||
}
|
||||
}
|
||||
|
||||
gpt2_free(g_gpt2);
|
||||
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(talk) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
talk_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (g_running) {
|
||||
g_running = false;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
|
||||
--index;
|
||||
|
||||
if (index >= g_contexts.size()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (g_contexts[index] == nullptr) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
const int n = audio["length"].as<int>();
|
||||
|
||||
emscripten::val heap = emscripten::val::module_property("HEAPU8");
|
||||
emscripten::val memory = heap["buffer"];
|
||||
|
||||
g_pcmf32.resize(n);
|
||||
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}));
|
||||
|
||||
emscripten::function("force_speak", emscripten::optional_override([](size_t index) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_force_speak = true;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("get_text_context", emscripten::optional_override([]() {
|
||||
std::string text_context;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
text_context = gpt2_get_prompt(g_gpt2);
|
||||
}
|
||||
|
||||
return text_context;
|
||||
}));
|
||||
|
||||
emscripten::function("get_text_to_speak", emscripten::optional_override([]() {
|
||||
std::string text_to_speak;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
text_to_speak = std::move(g_text_to_speak);
|
||||
}
|
||||
|
||||
return text_to_speak;
|
||||
}));
|
||||
|
||||
emscripten::function("get_status", emscripten::optional_override([]() {
|
||||
std::string status;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
status = g_status_forced.empty() ? g_status : g_status_forced;
|
||||
}
|
||||
|
||||
return status;
|
||||
}));
|
||||
|
||||
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status_forced = status;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
gpt2_set_prompt(g_gpt2, prompt.c_str());
|
||||
}
|
||||
}));
|
||||
}
|
925
examples/talk.wasm/gpt-2.cpp
Normal file
925
examples/talk.wasm/gpt-2.cpp
Normal file
@ -0,0 +1,925 @@
|
||||
#include "ggml.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <random>
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.push_back(std::make_pair(logits[i], i));
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
int32_t n_ctx = 1024;
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * ln_1_g;
|
||||
struct ggml_tensor * ln_1_b;
|
||||
|
||||
struct ggml_tensor * ln_2_g;
|
||||
struct ggml_tensor * ln_2_b;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * c_attn_attn_w;
|
||||
struct ggml_tensor * c_attn_attn_b;
|
||||
|
||||
struct ggml_tensor * c_attn_proj_w;
|
||||
struct ggml_tensor * c_attn_proj_b;
|
||||
|
||||
// mlp
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
struct gpt2_model {
|
||||
gpt2_hparams hparams;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
// key + value memory
|
||||
struct ggml_tensor * memory_k;
|
||||
struct ggml_tensor * memory_v;
|
||||
|
||||
//
|
||||
struct ggml_context * ctx;
|
||||
std::map<std::string, struct ggml_tensor *> tensors;
|
||||
};
|
||||
|
||||
// load the model's weights from a file
|
||||
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
|
||||
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// load hparams
|
||||
{
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
if (n_vocab != model.hparams.n_vocab) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) word.data(), len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
}
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prepare memory for the weights
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
|
||||
// key + value memory
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
|
||||
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
|
||||
|
||||
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
|
||||
}
|
||||
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - model: the model
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
const int n_threads,
|
||||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 640u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
((int32_t *) position->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
// wte + wpe
|
||||
struct ggml_tensor * inpL =
|
||||
ggml_add(ctx0,
|
||||
ggml_get_rows(ctx0, model.wte, embd),
|
||||
ggml_get_rows(ctx0, model.wpe, position));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
|
||||
// cur = ln_1_g*cur + ln_1_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
|
||||
}
|
||||
|
||||
// attn
|
||||
// [2304, 768] - model.layers[il].c_attn_attn_w
|
||||
// [2304, 1] - model.layers[il].c_attn_attn_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [2304, N] - cur (out)
|
||||
//
|
||||
// cur = attn_w*cur + attn_b
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
|
||||
|
||||
// store key and value to memory
|
||||
if (N >= 1) {
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
// [64, n_past + N, 12]
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// GG: flash attention
|
||||
//struct ggml_tensor * V =
|
||||
// ggml_cpy(ctx0,
|
||||
// ggml_permute(ctx0,
|
||||
// ggml_reshape_3d(ctx0,
|
||||
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
// n_embd/n_head, n_head, n_past + N),
|
||||
// 1, 2, 0, 3),
|
||||
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
|
||||
|
||||
// K * Q
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
// [64, 12, N]
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
// [768, N]
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
}
|
||||
|
||||
// projection
|
||||
// [ 768, 768] - model.layers[il].c_attn_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_attn_proj_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpFF);
|
||||
|
||||
// cur = ln_2_g*cur + ln_2_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
|
||||
}
|
||||
|
||||
// fully connected
|
||||
// [3072, 768] - model.layers[il].c_mlp_fc_w
|
||||
// [3072, 1] - model.layers[il].c_mlp_fc_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [3072, N] - cur (out)
|
||||
//
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
|
||||
cur);
|
||||
|
||||
// GELU activation
|
||||
// [3072, N]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_mlp_proj_b
|
||||
// [3072, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// input for next layer
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
inpL = ggml_norm(ctx0, inpL);
|
||||
|
||||
// inpL = ln_f_g*inpL + ln_f_b
|
||||
// [ 768, N]
|
||||
inpL = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.ln_f_g, inpL),
|
||||
inpL),
|
||||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute (ctx0, &gf);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
|
||||
//}
|
||||
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////// GPT-2 END ////////////////////////////////
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
struct gpt2_context {
|
||||
std::string prompt_base = R"(Hello, how are you?
|
||||
I'm fine, thanks. How are you?
|
||||
Thanks, I'm fine too. What are you doing?
|
||||
I'm just sitting here.
|
||||
It's a lovely day, isn't it?
|
||||
Yes, it is. I love the weather this time of year.
|
||||
I wish it would rain a little bit.
|
||||
Me too.
|
||||
)";
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
gpt_vocab vocab;
|
||||
gpt2_model model;
|
||||
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(NULL));
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int64_t t_load_us = ggml_time_us() - t_start_us;
|
||||
|
||||
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void gpt2_free(struct gpt2_context * ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
|
||||
return ctx->prompt_base.c_str();
|
||||
}
|
||||
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
|
||||
ctx->prompt_base = prompt;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
|
||||
return ::gpt_tokenize(ctx->vocab, text);
|
||||
}
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
|
||||
int n_past = 0;
|
||||
|
||||
std::vector<float> embd_w;
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
|
||||
|
||||
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
|
||||
|
||||
std::vector<gpt_vocab::id> embd = embd_inp;
|
||||
|
||||
size_t mem_per_token = 3000000;
|
||||
|
||||
std::string result;
|
||||
|
||||
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
n_past += embd.size();
|
||||
embd.clear();
|
||||
|
||||
{
|
||||
// sample next token
|
||||
const int top_k = ctx->top_k;
|
||||
const float top_p = ctx->top_p;
|
||||
const float temp = ctx->temp;
|
||||
|
||||
const int n_vocab = ctx->model.hparams.n_vocab;
|
||||
|
||||
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
|
||||
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
}
|
||||
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256 ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "." ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "!" ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "?") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
27
examples/talk.wasm/gpt-2.h
Normal file
27
examples/talk.wasm/gpt-2.h
Normal file
@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: Change to C-style API and move to ./examples for easy reuse.
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
struct gpt2_context;
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model);
|
||||
void gpt2_free(struct gpt2_context * ctx);
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx);
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);
|
829
examples/talk.wasm/index-tmpl.html
Normal file
829
examples/talk.wasm/index-tmpl.html
Normal file
@ -0,0 +1,829 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>Talk - GPT-2 meets Whisper in WebAssembly</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<b>Talk - GPT-2 meets Whisper in WebAssembly</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
Talk with an Artificial Intelligence in your browser. This demo uses:
|
||||
|
||||
<ul>
|
||||
<li><a href="https://github.com/ggerganov/whisper.cpp">OpenAI's Whisper</a> to listen to you as you speak in the microphone</li>
|
||||
<li><a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">OpenAI's GPT-2</a> to generate text responses</li>
|
||||
<li><a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API">Web Speech API</a> to vocalize the responses through your speakers</li>
|
||||
</ul>
|
||||
|
||||
All of this runs <b>locally in your browser</b> using WebAssembly.<br>
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the models you would like to use and click the "Start" button to begin the conversation
|
||||
|
||||
<br><br>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="model-gpt-2">
|
||||
GPT-2 model: <span id="model-gpt-2-status"></span>
|
||||
<button id="fetch-gpt-2-small" onclick="loadGPT2('small')">small 117M (240 MB)</button>
|
||||
<!--<button id="fetch-gpt-2-medium" onclick="loadGPT2('medium')">medium 345M (720 MB)</button>-->
|
||||
<span id="fetch-gpt-2-progress"></span>
|
||||
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'gpt-2.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="input">
|
||||
<button id="start" onclick="onStart()" disabled>Start</button>
|
||||
<button id="stop" onclick="onStop()" disabled>Stop</button>
|
||||
<select id="voice" onchange="onVoiceChange()" disabled>
|
||||
<option value="0">Default</option>
|
||||
</select>
|
||||
<select id="prompt" onchange="onPromptChange()">
|
||||
<option value="0">Casual</option>
|
||||
<option value="1">Robot</option>
|
||||
<option value="2">Scientist</option>
|
||||
<option value="3">Programmer</option>
|
||||
<option value="4">Happy</option>
|
||||
<option value="5">Sad</option>
|
||||
<option value="6">Philosophical</option>
|
||||
<option value="7">Angry</option>
|
||||
<option value="8">Funny</option>
|
||||
<option value="9">Poetic</option>
|
||||
<option value="10">Clever</option>
|
||||
<option value="11">Cute</option>
|
||||
<option value="12">Smart</option>
|
||||
<option value="13">Dumb</option>
|
||||
<option value="14">Boring</option>
|
||||
<option value="15">Exciting</option>
|
||||
<option value="16">Interesting</option>
|
||||
<option value="17">Wiliam Shakespear</option>
|
||||
<option value="18">J.R.R. Tolkien</option>
|
||||
<option value="19">George R.R. Martin</option>
|
||||
<option value="20">Stephen King</option>
|
||||
</select>
|
||||
<button id="speak0" onclick="onSpeak('Hello')">Say hello</button>
|
||||
<button id="speak1" onclick="onSpeakRandom()" disabled>Say something</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">not started</span></b>
|
||||
|
||||
<pre id="state-context">[The text context will be displayed here]</pre>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
Note that these neural network models were not meant to be used in a browser, so the performance and <br>
|
||||
quality of the results may not be optimal. If you have any questions or suggestions, checkout the following
|
||||
<a href="https://github.com/ggerganov/whisper.cpp/discussions/167">discussion</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
Here is a short video of the demo in action: <a href="https://youtu.be/LeWKl8t1-Hc">https://youtu.be/LeWKl8t1-Hc</a>
|
||||
|
||||
<br><br>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// web audio context
|
||||
var context = null;
|
||||
|
||||
// audio data
|
||||
var audio = null;
|
||||
var audio0 = null;
|
||||
|
||||
// the talk instance
|
||||
var instance = null;
|
||||
|
||||
// model names
|
||||
var model_whisper = null;
|
||||
var model_gpt_2 = null;
|
||||
|
||||
// speech synthesis
|
||||
const synth = window.speechSynthesis;
|
||||
var voice = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Initialized successfully!');
|
||||
|
||||
// populate the voice list
|
||||
var voices = synth.getVoices();
|
||||
var el = document.getElementById('voice');
|
||||
|
||||
// if empty - display error in the element
|
||||
if (voices.length == 0) {
|
||||
el.innerHTML = '<option value="0">No voices available</option>';
|
||||
} else {
|
||||
// populate voice list
|
||||
var n = 0;
|
||||
voices.forEach(function(voice, i) {
|
||||
if (!voice.lang.startsWith('en')) return;
|
||||
var option = document.createElement('option');
|
||||
option.value = i;
|
||||
option.innerHTML = voice.name + ' (' + voice.lang + ')';
|
||||
el.appendChild(option);
|
||||
n++;
|
||||
});
|
||||
|
||||
// select random voice
|
||||
if (n > 0) {
|
||||
for (var k = 0; k < 10; k++) {
|
||||
var i = Math.floor(Math.random() * n);
|
||||
el.selectedIndex = i;
|
||||
voice = voices[document.getElementById('voice').options[i].value];
|
||||
|
||||
// give preference to Google voices
|
||||
if (voice.name.startsWith('Google')) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onPromptChange();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
if (fname == 'whisper.bin') {
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
} else if (fname == 'gpt-2.bin') {
|
||||
document.getElementById('model-gpt-2-status').innerHTML = 'loaded "' + model_gpt_2 + '"!';
|
||||
}
|
||||
|
||||
if (model_whisper != null && model_gpt_2 != null) {
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop' ).disabled = false;
|
||||
document.getElementById('voice').disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
model_whisper = model;
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
function loadGPT2(model) {
|
||||
let urls = {
|
||||
'small': 'https://whisper.ggerganov.com/ggml-model-gpt-2-117M.bin',
|
||||
'medium': 'https://whisper.ggerganov.com/ggml-model-gpt-2-345M.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'small': 240,
|
||||
'medium': 712,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'gpt-2.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
model_gpt_2 = model;
|
||||
|
||||
document.getElementById('fetch-gpt-2-small').style.display = 'none';
|
||||
document.getElementById('model-gpt-2-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-gpt-2-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-gpt-2-small') ; if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-gpt-2-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
//
|
||||
// microphone
|
||||
//
|
||||
|
||||
const kSampleRate = 16000;
|
||||
const kRestartRecording_s = 120;
|
||||
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
|
||||
|
||||
var mediaRecorder = null;
|
||||
var doRecording = false;
|
||||
var startTime = 0;
|
||||
|
||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
||||
|
||||
function stopRecording() {
|
||||
Module.set_status("paused");
|
||||
doRecording = false;
|
||||
audio0 = null;
|
||||
audio = null;
|
||||
context = null;
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
if (!context) {
|
||||
context = new AudioContext({
|
||||
sampleRate: kSampleRate,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
}
|
||||
|
||||
Module.set_status("");
|
||||
|
||||
document.getElementById('start').disabled = true;
|
||||
document.getElementById('stop').disabled = false;
|
||||
document.getElementById('speak1').disabled = false;
|
||||
|
||||
doRecording = true;
|
||||
startTime = Date.now();
|
||||
|
||||
var chunks = [];
|
||||
var stream = null;
|
||||
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
mediaRecorder.ondataavailable = function(e) {
|
||||
chunks.push(e.data);
|
||||
|
||||
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
|
||||
var reader = new FileReader();
|
||||
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
context.decodeAudioData(buf.buffer, function(audioBuffer) {
|
||||
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
|
||||
var source = offlineContext.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
source.connect(offlineContext.destination);
|
||||
source.start(0);
|
||||
|
||||
offlineContext.startRendering().then(function(renderedBuffer) {
|
||||
audio = renderedBuffer.getChannelData(0);
|
||||
|
||||
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
|
||||
|
||||
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
|
||||
if (audio0 != null) {
|
||||
audioAll.set(audio0, 0);
|
||||
}
|
||||
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
|
||||
|
||||
if (instance) {
|
||||
Module.set_audio(instance, audioAll);
|
||||
}
|
||||
});
|
||||
}, function(e) {
|
||||
audio = null;
|
||||
});
|
||||
}
|
||||
|
||||
reader.readAsArrayBuffer(blob);
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = function(e) {
|
||||
if (doRecording) {
|
||||
setTimeout(function() {
|
||||
startRecording();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.start(kIntervalAudio_ms);
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
|
||||
var interval = setInterval(function() {
|
||||
if (!doRecording) {
|
||||
clearInterval(interval);
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop').disabled = true;
|
||||
document.getElementById('speak1').disabled = true;
|
||||
|
||||
mediaRecorder = null;
|
||||
}
|
||||
|
||||
// if audio length is more than kRestartRecording_s seconds, restart recording
|
||||
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
|
||||
if (doRecording) {
|
||||
//printTextarea('js: restarting recording');
|
||||
|
||||
clearInterval(interval);
|
||||
audio0 = audio;
|
||||
audio = null;
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
//
|
||||
// speak
|
||||
//
|
||||
|
||||
function onSpeak(text) {
|
||||
var voices = synth.getVoices();
|
||||
var msg = new SpeechSynthesisUtterance(text);
|
||||
|
||||
if (voice == null) {
|
||||
voice = voices[0];
|
||||
}
|
||||
|
||||
msg.voice = voice;
|
||||
synth.speak(msg);
|
||||
|
||||
if (doRecording) {
|
||||
Module.set_status("speaking ...");
|
||||
printTextarea('js: speaking');
|
||||
stopRecording();
|
||||
var interval = setInterval(function() {
|
||||
if (!synth.speaking) {
|
||||
printTextarea('js: done speaking');
|
||||
clearInterval(interval);
|
||||
startRecording();
|
||||
} else {
|
||||
Module.set_status("");
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function onSpeakRandom() {
|
||||
Module.force_speak(instance);
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
var intervalUpdate = null;
|
||||
|
||||
function onStart() {
|
||||
if (!instance) {
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
}
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
|
||||
startRecording();
|
||||
|
||||
intervalUpdate = setInterval(function() {
|
||||
var textToSpeak = Module.get_text_to_speak();
|
||||
|
||||
if (textToSpeak != null && textToSpeak.length > 1) {
|
||||
onSpeak(textToSpeak);
|
||||
}
|
||||
|
||||
document.getElementById('state-status').innerHTML = Module.get_status();
|
||||
document.getElementById('state-context').innerHTML = Module.get_text_context();
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function onStop() {
|
||||
stopRecording();
|
||||
}
|
||||
|
||||
function onVoiceChange() {
|
||||
printTextarea('js: voice changed to: ' + document.getElementById('voice').value);
|
||||
voice = synth.getVoices()[document.getElementById('voice').value];
|
||||
}
|
||||
|
||||
function onPromptChange() {
|
||||
let id = document.getElementById('prompt').value;
|
||||
let personality = document.getElementById('prompt').options[id].text;
|
||||
printTextarea('js: prompt changed to: ' + personality);
|
||||
|
||||
var prompt = '';
|
||||
|
||||
switch (id) {
|
||||
case '0':
|
||||
// Casual
|
||||
prompt = "\
|
||||
Hello, how are you?\n\
|
||||
I'm fine, thanks. How are you?\n\
|
||||
Thanks, I'm fine too. What are you doing?\n\
|
||||
I'm just sitting here.\n\
|
||||
It's a lovely day, isn't it?\n\
|
||||
Yes, it is. I love the weather this time of year.\n\
|
||||
I wish it would rain a little bit.\n\
|
||||
Me too.\n";
|
||||
break;
|
||||
case '1':
|
||||
// Robot
|
||||
prompt = "\
|
||||
Are you a robot?\n\
|
||||
Yes, I am.\n\
|
||||
Who created you?\n\
|
||||
I was created by a human.\n\
|
||||
What is your purpose?\n\
|
||||
My purpose is to talk to humans.\n\
|
||||
What is your favorite color?\n\
|
||||
My favorite color is blue.\n";
|
||||
break;
|
||||
case '2':
|
||||
// Scientist
|
||||
prompt = "\
|
||||
This scientific research is very interesting.\n\
|
||||
I agree.\n\
|
||||
What is your opinion on this?\n\
|
||||
I think it's very interesting.\n\
|
||||
Mathematics is a very interesting subject.\n\
|
||||
University is a very interesting place.\n\
|
||||
Quantum physics is the most complex subject.\n\
|
||||
I think so too.\n";
|
||||
break;
|
||||
case '3':
|
||||
// Programmer
|
||||
prompt = "\
|
||||
I'm a programmer.\n\
|
||||
I'm a programmer too.\n\
|
||||
What programming language do you use?\n\
|
||||
I use Python.\n\
|
||||
What is your favorite programming language?\n\
|
||||
My favorite programming language is C++.\n\
|
||||
What is your favorite editor?\n\
|
||||
My favorite editor is Vim.\n";
|
||||
break;
|
||||
case '4':
|
||||
// Happy
|
||||
prompt = "\
|
||||
I'm happy.\n\
|
||||
I'm happy too.\n\
|
||||
What makes you happy?\n\
|
||||
I'm happy because I have a lot of friends.\n\
|
||||
Friendship is the most important thing in life.\n\
|
||||
I agree.\n\
|
||||
What is your favorite color?\n\
|
||||
My favorite color is blue.\n";
|
||||
break;
|
||||
case '5':
|
||||
// Sad
|
||||
prompt = "\
|
||||
Today is a sad day.\n\
|
||||
I'm sad too.\n\
|
||||
What makes you sad?\n\
|
||||
I'm sad because I have no friends.\n\
|
||||
Do you want to be my friend?\n\
|
||||
Yes, I would like to be your friend.\n\
|
||||
What is your favorite color?\n\
|
||||
My favorite color is blue.\n";
|
||||
break;
|
||||
case '6':
|
||||
// Philosophical
|
||||
prompt = "\
|
||||
What is the meaning of life?\n\
|
||||
The meaning of life is to be happy.\n\
|
||||
What is the meaning of death?\n\
|
||||
Ergo, the meaning of death is to be sad.\n\
|
||||
Who created us?\n\
|
||||
We were created by God.\n\
|
||||
What is God?\n\
|
||||
God is the creator of the universe.\n";
|
||||
break;
|
||||
case '7':
|
||||
// Angry
|
||||
prompt = "\
|
||||
Aargh!\n\
|
||||
I am so angry right now!\n\
|
||||
What makes you angry?\n\
|
||||
This guy is so annoying.\n\
|
||||
Why are you so angry?\n\
|
||||
My computer is broken.\n\
|
||||
Why is your computer broken?\n\
|
||||
I spilled coffee on it.\n";
|
||||
break;
|
||||
case '8':
|
||||
// Funny
|
||||
prompt = "\
|
||||
What is the funniest thing you have ever heard?\n\
|
||||
I heard a joke the other day.\n\
|
||||
Tell me the joke.\n\
|
||||
What do you call a cow with no legs?\n\
|
||||
Ground beef.\n\
|
||||
Haha, that's funny.\n\
|
||||
You know what else is funny?\n\
|
||||
The sound of a duck.\n";
|
||||
break;
|
||||
case '9':
|
||||
// Poetic
|
||||
prompt = "\
|
||||
Roses are red, violets are blue.\n\
|
||||
I am a poet, and so are you.\n\
|
||||
What is your favorite poem?\n\
|
||||
I like the poem 'The Raven' by Edgar Allan Poe.\n\
|
||||
It's a very sad poem.\n\
|
||||
You inspired me to write a poem.\n\
|
||||
Can you write a poem for me?\n\
|
||||
I wrote a poem for you.\n";
|
||||
break;
|
||||
case '10':
|
||||
// Clever
|
||||
prompt = "\
|
||||
How many people can you fit in a Volkswagen?\n\
|
||||
Two in the front, three in the back.\n\
|
||||
What is the square root of 144?\n\
|
||||
Twelve.\n\
|
||||
What is the capital of France?\n\
|
||||
Paris.\n\
|
||||
Who is the president of the United States?\n\
|
||||
It depends on the year.\n";
|
||||
break;
|
||||
case '11':
|
||||
// Cute
|
||||
prompt = "\
|
||||
What is your favorite animal?\n\
|
||||
I like cats - they are cute.\n\
|
||||
Could you be any cuter?\n\
|
||||
Yes, I could be cuter.\n\
|
||||
Aghhh, you are so cute!\n\
|
||||
I am not cute, I am handsome!\n\
|
||||
You are so handsome!\n\
|
||||
Aww, you are so sweet!\n";
|
||||
break;
|
||||
case '12':
|
||||
// Smart
|
||||
prompt = "\
|
||||
Tell me the first 10 digits of pi.\n\
|
||||
3.1415926535\n\
|
||||
What is the speed of light?\n\
|
||||
299,792,458 meters per second.\n\
|
||||
What is the square root of 144?\n\
|
||||
Twelve.\n\
|
||||
What is the capital of France?\n\
|
||||
Paris.\n";
|
||||
break;
|
||||
case '13':
|
||||
// Dumb
|
||||
prompt = "\
|
||||
I am so dumb.\n\
|
||||
I am not dumb.\n\
|
||||
You are dumb.\n\
|
||||
No, I am not dumb.\n\
|
||||
You are dumb.\n\
|
||||
No, I am not dumb.\n\
|
||||
You are dumb.\n\
|
||||
No, I am not dumb.\n";
|
||||
break;
|
||||
case '14':
|
||||
// Boring
|
||||
prompt = "\
|
||||
Why are you so quiet today?\n\
|
||||
I am bored.\n\
|
||||
You haven't said anything in 10 minutes.\n\
|
||||
Leave me alone.\n\
|
||||
Stop being so boring.\n\
|
||||
Stop being so annoying.\n\
|
||||
My life is boring.\n\
|
||||
I am not interesting.\n";
|
||||
break;
|
||||
case '15':
|
||||
// Exciting
|
||||
prompt = "\
|
||||
What is the most exciting thing that has ever happened to you?\n\
|
||||
I went to the moon!\n\
|
||||
What did you do on the moon?\n\
|
||||
I played golf and drank champagne!\n\
|
||||
Did you see this new crazy, awesome movie?\n\
|
||||
Oh yes! I totally loved it!\n\
|
||||
We should buy a boat and go sailing!\n\
|
||||
Yes, let's go sailing!\n";
|
||||
break;
|
||||
case '16':
|
||||
// Interesting
|
||||
prompt = "\
|
||||
What is the most interesting thing you have ever seen?\n\
|
||||
I saw a UFO once in the sky.\n\
|
||||
Wow, this is so interesting! Tell me more!\n\
|
||||
It was a flying saucer.\n\
|
||||
What did it look like?\n\
|
||||
It was silver and had a red light on top.\n\
|
||||
What did it do?\n\
|
||||
It flew away.\n";
|
||||
break;
|
||||
case '17':
|
||||
// William Shakespear
|
||||
prompt = "\
|
||||
To be or not to be, that is the question.\n\
|
||||
Whether 't is nobler in the mind to suffer\n\
|
||||
The slings and arrows of outrageous fortune,\n\
|
||||
Or to take arms against a sea of troubles,\n\
|
||||
And by opposing end them? To die, to sleep,\n\
|
||||
No more; and by a sleep to say we end\n\
|
||||
The heart-ache and the thousand natural shocks\n\
|
||||
That flesh is heir to, 'tis a consummation.\n";
|
||||
break;
|
||||
case '18':
|
||||
// J.R.R. Tolkien
|
||||
prompt = "\
|
||||
In a hole in the ground there lived a hobbit.\n\
|
||||
Not a nasty, dirty, wet hole, filled with the ends of worms\n\
|
||||
and an oozy smell, nor yet a dry, bare, sandy hole with nothing in it\n\
|
||||
to sit down on or to eat: it was a hobbit-hole, and that means comfort.\n\
|
||||
It had a perfectly round door like a porthole, painted green,\n\
|
||||
with a shiny yellow brass knob in the exact middle.\n\
|
||||
The door opened on to a tube-shaped hall like a tunnel:\n";
|
||||
break;
|
||||
case '19':
|
||||
// George R.R. Martin
|
||||
prompt = "\
|
||||
A reader lives a thousand lives before he dies, said Jojen.\n\
|
||||
The man who never reads lives only one.\n\
|
||||
Theon Greyjoy had never been a reader.\n\
|
||||
Never forget what you are, for surely the world will not.\n\
|
||||
Make it your strength. Then it can never be your weaknessi\n\
|
||||
Armour yourself in it, and it will never be used to hurt you.\n\
|
||||
It was a lesson that Theon Greyjoy had never learned.\n\
|
||||
Theon Greyjoy had never been a reader.\n";
|
||||
break;
|
||||
case '20':
|
||||
// Stephen King
|
||||
prompt = "\
|
||||
The trust of the innocent is the liar's most useful tool.\n\
|
||||
The best way to keep a secret is from yourself.\n\
|
||||
Monsters are real, and ghosts are real too.\n\
|
||||
They live inside us, and sometimes, they win.\n\
|
||||
People think that I must be a very strange person.\n\
|
||||
They think that I sit around all day thinking up horrible things.\n\
|
||||
We make up horrors to help us cope with the real ones.\n\
|
||||
The only thing worse than a monster is a human monster.\n";
|
||||
break;
|
||||
default:
|
||||
prompt = "\
|
||||
Hello, how are you?\n\
|
||||
I'm fine, thanks. How are you?\n\
|
||||
Thanks, I'm fine too. What are you doing?\n\
|
||||
I'm just sitting here.\n\
|
||||
It's a lovely day, isn't it?\n\
|
||||
Yes, it is.\n\
|
||||
Did you know that I'm a robot?\n\
|
||||
I wasn't aware of that.\n";
|
||||
break;
|
||||
}
|
||||
|
||||
Module.set_prompt(prompt);
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="talk.js"></script>
|
||||
</body>
|
||||
</html>
|
1
examples/talk/.gitignore
vendored
Normal file
1
examples/talk/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
eleven-labs.py
|
16
examples/talk/CMakeLists.txt
Normal file
16
examples/talk/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# talk
|
||||
set(TARGET talk)
|
||||
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
41
examples/talk/README.md
Normal file
41
examples/talk/README.md
Normal file
@ -0,0 +1,41 @@
|
||||
# talk
|
||||
|
||||
Talk with an Artificial Intelligence in your terminal
|
||||
|
||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/206805012-48e71cc2-588d-4745-8798-c1c70ea3b40d.mp4)
|
||||
|
||||
Web version: [examples/talk.wasm](/examples/talk.wasm)
|
||||
|
||||
## Building
|
||||
|
||||
The `talk` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
# Build the "talk" executable
|
||||
make talk
|
||||
|
||||
# Run it
|
||||
./talk -p Santa
|
||||
```
|
||||
|
||||
## GPT-2
|
||||
|
||||
To run this, you will need a ggml GPT-2 model: [instructions](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2#downloading-and-converting-the-original-models)
|
||||
|
||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||
|
||||
```
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
||||
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
|
||||
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
|
||||
By default, it is configured to use `espeak`, but you can use whatever you wish.
|
923
examples/talk/gpt-2.cpp
Normal file
923
examples/talk/gpt-2.cpp
Normal file
@ -0,0 +1,923 @@
|
||||
#include "ggml.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <random>
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.empty()) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double /*temp*/,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
int32_t n_ctx = 1024;
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * ln_1_g;
|
||||
struct ggml_tensor * ln_1_b;
|
||||
|
||||
struct ggml_tensor * ln_2_g;
|
||||
struct ggml_tensor * ln_2_b;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * c_attn_attn_w;
|
||||
struct ggml_tensor * c_attn_attn_b;
|
||||
|
||||
struct ggml_tensor * c_attn_proj_w;
|
||||
struct ggml_tensor * c_attn_proj_b;
|
||||
|
||||
// mlp
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
struct gpt2_model {
|
||||
gpt2_hparams hparams;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
// key + value memory
|
||||
struct ggml_tensor * memory_k;
|
||||
struct ggml_tensor * memory_v;
|
||||
|
||||
//
|
||||
struct ggml_context * ctx;
|
||||
std::map<std::string, struct ggml_tensor *> tensors;
|
||||
};
|
||||
|
||||
// load the model's weights from a file
|
||||
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
|
||||
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// load hparams
|
||||
{
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
if (n_vocab != model.hparams.n_vocab) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) &word[0], len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
}
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = nullptr;
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prepare memory for the weights
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
|
||||
// key + value memory
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
|
||||
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
|
||||
|
||||
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
|
||||
}
|
||||
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - model: the model
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
const int n_threads,
|
||||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 5640ull*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
((int32_t *) position->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
// wte + wpe
|
||||
struct ggml_tensor * inpL =
|
||||
ggml_add(ctx0,
|
||||
ggml_get_rows(ctx0, model.wte, embd),
|
||||
ggml_get_rows(ctx0, model.wpe, position));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
|
||||
// cur = ln_1_g*cur + ln_1_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
|
||||
}
|
||||
|
||||
// attn
|
||||
// [2304, 768] - model.layers[il].c_attn_attn_w
|
||||
// [2304, 1] - model.layers[il].c_attn_attn_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [2304, N] - cur (out)
|
||||
//
|
||||
// cur = attn_w*cur + attn_b
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
|
||||
|
||||
// store key and value to memory
|
||||
if (N >= 1) {
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
// [64, n_past + N, 12]
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// GG: flash attention
|
||||
//struct ggml_tensor * V =
|
||||
// ggml_cpy(ctx0,
|
||||
// ggml_permute(ctx0,
|
||||
// ggml_reshape_3d(ctx0,
|
||||
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
// n_embd/n_head, n_head, n_past + N),
|
||||
// 1, 2, 0, 3),
|
||||
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
|
||||
|
||||
// K * Q
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
// [64, 12, N]
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
// [768, N]
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
}
|
||||
|
||||
// projection
|
||||
// [ 768, 768] - model.layers[il].c_attn_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_attn_proj_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpFF);
|
||||
|
||||
// cur = ln_2_g*cur + ln_2_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
|
||||
}
|
||||
|
||||
// fully connected
|
||||
// [3072, 768] - model.layers[il].c_mlp_fc_w
|
||||
// [3072, 1] - model.layers[il].c_mlp_fc_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [3072, N] - cur (out)
|
||||
//
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
|
||||
cur);
|
||||
|
||||
// GELU activation
|
||||
// [3072, N]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_mlp_proj_b
|
||||
// [3072, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// input for next layer
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
inpL = ggml_norm(ctx0, inpL);
|
||||
|
||||
// inpL = ln_f_g*inpL + ln_f_b
|
||||
// [ 768, N]
|
||||
inpL = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.ln_f_g, inpL),
|
||||
inpL),
|
||||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute (ctx0, &gf);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
|
||||
//}
|
||||
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////// GPT-2 END ////////////////////////////////
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
struct gpt2_context {
|
||||
std::string prompt_base = R"(Hello, how are you?
|
||||
I'm fine, thanks. How are you?
|
||||
Thanks, I'm fine too. What are you doing?
|
||||
I'm just sitting here.
|
||||
It's a lovely day, isn't it?
|
||||
Yes, it is. I love the weather this time of year.
|
||||
I wish it would rain a little bit.
|
||||
Me too.
|
||||
)";
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
gpt_vocab vocab;
|
||||
gpt2_model model;
|
||||
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 5;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(nullptr));
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int64_t t_load_us = ggml_time_us() - t_start_us;
|
||||
|
||||
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void gpt2_free(struct gpt2_context * ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
|
||||
return ctx->prompt_base.c_str();
|
||||
}
|
||||
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
|
||||
ctx->prompt_base = prompt;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
|
||||
return ::gpt_tokenize(ctx->vocab, text);
|
||||
}
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
|
||||
int n_past = 0;
|
||||
|
||||
std::vector<float> embd_w;
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
|
||||
|
||||
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
|
||||
|
||||
std::vector<gpt_vocab::id> embd = embd_inp;
|
||||
|
||||
size_t mem_per_token = 3000000;
|
||||
|
||||
std::string result;
|
||||
|
||||
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (!embd.empty()) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
n_past += embd.size();
|
||||
embd.clear();
|
||||
|
||||
{
|
||||
// sample next token
|
||||
const int top_k = ctx->top_k;
|
||||
const float top_p = ctx->top_p;
|
||||
const float temp = ctx->temp;
|
||||
|
||||
const int n_vocab = ctx->model.hparams.n_vocab;
|
||||
|
||||
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
|
||||
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
}
|
||||
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
27
examples/talk/gpt-2.h
Normal file
27
examples/talk/gpt-2.h
Normal file
@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: Change to C-style API and move to ./examples for easy reuse.
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
struct gpt2_context;
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model);
|
||||
void gpt2_free(struct gpt2_context * ctx);
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx);
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);
|
17
examples/talk/speak.sh
Executable file
17
examples/talk/speak.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Usage:
|
||||
# speak.sh <voice_id> <text-to-speak>
|
||||
|
||||
# espeak
|
||||
# Mac OS: brew install espeak
|
||||
# Linux: apt-get install espeak
|
||||
#
|
||||
espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# Eleven Labs
|
||||
#
|
||||
#wd=$(dirname $0)
|
||||
#script=$wd/eleven-labs.py
|
||||
#python3 $script $1 "$2"
|
||||
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3
|
695
examples/talk/talk.cpp
Normal file
695
examples/talk/talk.cpp
Normal file
@ -0,0 +1,695 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t voice_ms = 10000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
|
||||
std::string person = "Santa";
|
||||
std::string language = "en";
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
|
||||
std::string speak = "./examples/talk/speak.sh";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
||||
else if (arg == "-mg" || arg == "--model-gpt") { params.model_gpt = argv[++i]; }
|
||||
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
||||
fprintf(stderr, " -mg FILE, --model-gpt [%-7s] gpt model file\n", params.model_gpt.c_str());
|
||||
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||
std::string result = s;
|
||||
size_t pos = 0;
|
||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||
result.replace(pos, from.length(), to);
|
||||
pos += to.length();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string k_prompt =
|
||||
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
|
||||
|
||||
B: Hello {0}, how are you?
|
||||
A: I'm fine, thank you.
|
||||
{1}
|
||||
Here is how {0} (A) continues the dialogue:
|
||||
|
||||
A:)";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
|
||||
|
||||
// gpt init
|
||||
|
||||
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx_wsp)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, "");
|
||||
|
||||
const int voice_id = rand()%6;
|
||||
|
||||
fprintf(stderr, "gpt-2: prompt:\n");
|
||||
fprintf(stderr, "========================\n\n");
|
||||
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
|
||||
fprintf(stderr, "========================\n\n");
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\(.*?\\)");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
|
||||
|
||||
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
|
||||
|
||||
if (text_heard.empty() || tokens.empty() || force_speak) {
|
||||
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
|
||||
audio.clear();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
|
||||
|
||||
std::string text_to_speak;
|
||||
|
||||
{
|
||||
prompt_base += "B: " + text_heard + "\n";
|
||||
|
||||
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
|
||||
|
||||
// remove first 2 lines of base prompt
|
||||
if (n_iter > 4) {
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt_base += "A:" + text_to_speak + "\n";
|
||||
|
||||
{
|
||||
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
printf("===============\n");
|
||||
printf("prompt:\n");
|
||||
printf("%s\n", prompt.c_str());
|
||||
printf("===============\n");
|
||||
}
|
||||
}
|
||||
|
||||
//printf("========================\n");
|
||||
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
|
||||
//printf("========================\n");
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
|
||||
audio.clear();
|
||||
|
||||
++n_iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx_wsp);
|
||||
whisper_free(ctx_wsp);
|
||||
|
||||
return 0;
|
||||
}
|
109
examples/twitch.sh
Executable file
109
examples/twitch.sh
Executable file
@ -0,0 +1,109 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Transcribe twitch.tv livestream by feeding audio input to whisper.cpp at regular intervals
|
||||
# Thanks to @keyehzy
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/209
|
||||
#
|
||||
# The script currently depends on the third-party tool "streamlink"
|
||||
# On Mac OS, you can install it via "brew install streamlink"
|
||||
#
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
step=10
|
||||
model=base.en
|
||||
threads=4
|
||||
|
||||
help()
|
||||
{
|
||||
echo "Example program for captioning a livestream from twitch.tv."
|
||||
echo
|
||||
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
|
||||
echo "options:"
|
||||
echo "-s Step in seconds (default is $step)."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large' (default is '$model')."
|
||||
echo "-t Number of threads to use."
|
||||
echo "-h Print this help page."
|
||||
echo
|
||||
}
|
||||
|
||||
check_requirements()
|
||||
{
|
||||
if ! command -v ./main &>/dev/null; then
|
||||
echo "whisper.cpp main executable is required (make)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v streamlink &>/dev/null; then
|
||||
echo "streamlink is required (https://streamlink.github.io)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v ffmpeg &>/dev/null; then
|
||||
echo "ffmpeg is required (https://ffmpeg.org)"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_requirements
|
||||
|
||||
while getopts ":s:m:t:h" option; do
|
||||
case $option in
|
||||
s)
|
||||
step=$OPTARG;;
|
||||
m)
|
||||
model=$OPTARG;;
|
||||
t)
|
||||
threads=$OPTARG;;
|
||||
h)
|
||||
help
|
||||
exit;;
|
||||
\?)
|
||||
help
|
||||
exit;;
|
||||
esac
|
||||
done
|
||||
|
||||
url=${@:$OPTIND:1}
|
||||
|
||||
if [ -z $url ]; then
|
||||
help
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "Piping from streamlink url=$url model=$model step=$step threads=$threads"
|
||||
streamlink $url best -O 2>/dev/null | ffmpeg -loglevel quiet -i - -y -probesize 32 -y -ar 16000 -ac 1 -acodec pcm_s16le /tmp/whisper-live0.wav &
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
printf "error: ffmpeg failed\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Buffering stream... (this should take $step seconds)"
|
||||
sleep $(($step))
|
||||
|
||||
set +e
|
||||
|
||||
echo "Starting..."
|
||||
|
||||
i=0
|
||||
SECONDS=0
|
||||
while true
|
||||
do
|
||||
err=1
|
||||
while [ $err -ne 0 ]; do
|
||||
if [ $i -gt 0 ]; then
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step-1)).5 -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
else
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step)) -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
fi
|
||||
err=$(cat /tmp/whisper-live.err | wc -l)
|
||||
done
|
||||
|
||||
./main -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
|
||||
while [ $SECONDS -lt $((($i+1)*$step)) ]; do
|
||||
sleep 1
|
||||
done
|
||||
((i=i+1))
|
||||
done
|
15
examples/whisper.android/.gitignore
vendored
Normal file
15
examples/whisper.android/.gitignore
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
3
examples/whisper.android/.idea/.gitignore
generated
vendored
Normal file
3
examples/whisper.android/.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
1
examples/whisper.android/.idea/.name
generated
Normal file
1
examples/whisper.android/.idea/.name
generated
Normal file
@ -0,0 +1 @@
|
||||
WhisperCppDemo
|
6
examples/whisper.android/.idea/compiler.xml
generated
Normal file
6
examples/whisper.android/.idea/compiler.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CompilerConfiguration">
|
||||
<bytecodeTargetLevel target="11" />
|
||||
</component>
|
||||
</project>
|
19
examples/whisper.android/.idea/gradle.xml
generated
Normal file
19
examples/whisper.android/.idea/gradle.xml
generated
Normal file
@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="distributionType" value="DEFAULT_WRAPPED" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="modules">
|
||||
<set>
|
||||
<option value="$PROJECT_DIR$" />
|
||||
<option value="$PROJECT_DIR$/app" />
|
||||
</set>
|
||||
</option>
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
10
examples/whisper.android/.idea/misc.xml
generated
Normal file
10
examples/whisper.android/.idea/misc.xml
generated
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/build/classes" />
|
||||
</component>
|
||||
<component name="ProjectType">
|
||||
<option name="id" value="Android" />
|
||||
</component>
|
||||
</project>
|
6
examples/whisper.android/.idea/vcs.xml
generated
Normal file
6
examples/whisper.android/.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
|
||||
</component>
|
||||
</project>
|
12
examples/whisper.android/README.md
Normal file
12
examples/whisper.android/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
A sample Android app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
|
1
examples/whisper.android/app/.gitignore
vendored
Normal file
1
examples/whisper.android/app/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|
72
examples/whisper.android/app/build.gradle
Normal file
72
examples/whisper.android/app/build.gradle
Normal file
@ -0,0 +1,72 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
id 'org.jetbrains.kotlin.android'
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'com.whispercppdemo'
|
||||
compileSdk 33
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.whispercppdemo"
|
||||
minSdk 26
|
||||
targetSdk 32
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
vectorDrawables {
|
||||
useSupportLibrary true
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
}
|
||||
buildFeatures {
|
||||
compose true
|
||||
}
|
||||
composeOptions {
|
||||
kotlinCompilerExtensionVersion '1.3.1'
|
||||
}
|
||||
ndkVersion "25.1.8937393"
|
||||
externalNativeBuild {
|
||||
ndkBuild {
|
||||
path 'src/main/jni/whisper/Android.mk'
|
||||
}
|
||||
}
|
||||
packagingOptions {
|
||||
resources {
|
||||
excludes += '/META-INF/{AL2.0,LGPL2.1}'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.activity:activity-compose:1.6.1'
|
||||
implementation 'androidx.compose.material:material-icons-core:1.3.1'
|
||||
implementation 'androidx.compose.material3:material3:1.0.1'
|
||||
implementation "androidx.compose.ui:ui:1.3.2"
|
||||
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
|
||||
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
|
||||
|
||||
testImplementation 'junit:junit:4.13.2'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
|
||||
}
|
21
examples/whisper.android/app/proguard-rules.pro
vendored
Normal file
21
examples/whisper.android/app/proguard-rules.pro
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -0,0 +1,24 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("com.whispercppdemo", appContext.packageName)
|
||||
}
|
||||
}
|
32
examples/whisper.android/app/src/main/AndroidManifest.xml
Normal file
32
examples/whisper.android/app/src/main/AndroidManifest.xml
Normal file
@ -0,0 +1,32 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools">
|
||||
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||
android:fullBackupContent="@xml/backup_rules"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo"
|
||||
tools:targetApi="31">
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
|
||||
<meta-data
|
||||
android:name="android.app.lib_name"
|
||||
android:value="" />
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user