mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-24 17:15:19 +00:00
Compare commits
235 Commits
Author | SHA1 | Date | |
---|---|---|---|
fa8dbdc888 | |||
4a7d49af95 | |||
794b162a46 | |||
5fd1bdd7fc | |||
0ccd6746c9 | |||
d9b550c0a1 | |||
e9b091c92a | |||
1f30b99208 | |||
05c3ea3bc8 | |||
6108d3cc58 | |||
bab97c83d0 | |||
3eaeb030ff | |||
acec73ab6e | |||
5cc17418c7 | |||
3efb81dec6 | |||
94a7cd2a07 | |||
3e82ff4747 | |||
b5bd2f43c5 | |||
94aa56f19e | |||
4d89ee2e59 | |||
70567eff23 | |||
02ec83c5d5 | |||
2bd4b8d577 | |||
eecf2c3d41 | |||
c23588cc4b | |||
5108b30e6d | |||
f19e23fbd1 | |||
ea1f8a50d4 | |||
3dead611bb | |||
355da83690 | |||
3e5c49e59a | |||
5e47e223bd | |||
794ff3074a | |||
7e2afa4384 | |||
1c5edc3cb3 | |||
34b772727d | |||
2c856fb9e5 | |||
7727a40dc9 | |||
b5639ed313 | |||
2c4ac2627d | |||
674a8e579b | |||
001083a769 | |||
62b51c3070 | |||
61128870b8 | |||
78548dc03f | |||
66110dafcc | |||
b73a4638ac | |||
5f16420333 | |||
ccb47e7e10 | |||
677ad754a0 | |||
514cd04452 | |||
6704a81255 | |||
463e46338c | |||
2f889132c6 | |||
ebef1e8620 | |||
114df388fe | |||
ea36831459 | |||
69b8503935 | |||
0a2d1210bc | |||
859ffc994e | |||
5e6e2187a3 | |||
a7f1f33715 | |||
86ecfc6333 | |||
18e6fb0287 | |||
0f759f125d | |||
eefed45e37 | |||
aac1710afb | |||
21c1e6afc5 | |||
a47e812a54 | |||
42c6855103 | |||
0be9cd3497 | |||
e5c197d8aa | |||
7cd1d3bc34 | |||
82637b8e9f | |||
4a0deb8b1e | |||
8e361d90d7 | |||
fc49c44426 | |||
aec01bb337 | |||
21165580a1 | |||
1d749919e3 | |||
d4fa0d92ad | |||
a5e60c019d | |||
8fcd1a3b32 | |||
992aa2cd1b | |||
4aa3bcf8a4 | |||
1beff6f66d | |||
09e9068007 | |||
fa9d43181f | |||
bb6b54a03d | |||
b597c5a779 | |||
a3fb6c507f | |||
59fdcd19c8 | |||
478289a4b3 | |||
5e94129cb2 | |||
72af0f5697 | |||
af005d573f | |||
ad1389003d | |||
f420de1322 | |||
d176160f6f | |||
ca21f7ab16 | |||
373043cabe | |||
fb4d0d470f | |||
0d229163bb | |||
f254e78737 | |||
a94897bcde | |||
2407ae8ef0 | |||
b623ca43b1 | |||
69e6e4644a | |||
09d7d2b68e | |||
0336161b7d | |||
459753342d | |||
9764782bd9 | |||
3b010f9bed | |||
113fcec513 | |||
cfc06bf8df | |||
2bfe0ebc0f | |||
4dd7119deb | |||
ab1916fc59 | |||
a1c1583cc7 | |||
d012b5c7e4 | |||
b2083c5d02 | |||
f3ee4a9673 | |||
c306a7fd89 | |||
b2fc4c7010 | |||
291980369c | |||
86ef64a855 | |||
3b1960520a | |||
2bee2650c6 | |||
beb9512be3 | |||
47737b2e82 | |||
b992f3709e | |||
60337f5306 | |||
02c7516c57 | |||
411ea9b833 | |||
11f61cecd6 | |||
b5ddb16ec7 | |||
ae16c21e9c | |||
2c3f50a021 | |||
9a65269a20 | |||
78f166174f | |||
21c569ba4a | |||
1a91c19af9 | |||
f583e2d2f5 | |||
206fc93396 | |||
a6cf6f4c4a | |||
472a473fd1 | |||
9ba66c2fad | |||
1ccb8a46a5 | |||
1290fc6457 | |||
49b529ba74 | |||
8088a977af | |||
c9aeb33676 | |||
4a3f0d3fe9 | |||
874bde887e | |||
8738427dd6 | |||
c3991bbb24 | |||
00ea21668b | |||
0b85e8c401 | |||
fafd78945d | |||
8de452c18b | |||
a6dbd9188b | |||
4ef3398e8f | |||
5e9f33596f | |||
8d7b29cedd | |||
08dc705a69 | |||
1512545149 | |||
52a3e0c92a | |||
d1ea1220ff | |||
9c4a1522f6 | |||
f078a6f20e | |||
f30b5d322c | |||
44efbf7ff1 | |||
d347a59a5f | |||
6394c906af | |||
74ffa14e1d | |||
65fdcbbbbb | |||
d61d55cd4b | |||
d51fc3ee0a | |||
f82a7dd019 | |||
87dd4a3081 | |||
41e05c6b1b | |||
fa379cb22a | |||
322f4e6c4e | |||
1652965529 | |||
6042c7a3be | |||
6b351bb669 | |||
a62170c656 | |||
1944e7c33e | |||
49a8dd6732 | |||
8c7f642286 | |||
ad2a4ffa03 | |||
b3c865083e | |||
a0d4f8e65c | |||
4a214d2f07 | |||
0a0cfa7985 | |||
196d738974 | |||
84c6b42e65 | |||
dd6d582977 | |||
d51c5eb906 | |||
0be6a1afd9 | |||
a466c3404d | |||
d629c034a4 | |||
f00509d57c | |||
424c410c42 | |||
d97e6005e9 | |||
3467230a77 | |||
a091581eb3 | |||
68daf6e487 | |||
a593b932e4 | |||
9a8ad3db69 | |||
4e0b2069e7 | |||
ac521a566e | |||
331c0bbddc | |||
dc90efd504 | |||
7282e2109e | |||
466ceebb78 | |||
77226aa89d | |||
543bd5627e | |||
62fee9a9cc | |||
493d94130d | |||
1480a5f1af | |||
0f4227d9ee | |||
4c1fe0c813 | |||
fa463313ad | |||
501a6b455c | |||
91fc08c641 | |||
e1432dd91a | |||
22193cbfe8 | |||
42c6730732 | |||
76b6211f9b | |||
86a277f78d | |||
231bebca7d | |||
90564f85f9 | |||
99da1e5cc8 | |||
8e3f129b4d |
22
.github/workflows/bindings-go.yml
vendored
Normal file
22
.github/workflows/bindings-go.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Bindings Tests (Go)
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '^1.19'
|
||||
- uses: actions/checkout@v1
|
||||
- run: |
|
||||
cd bindings/go
|
||||
make test
|
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Bindings Tests (Ruby)
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/ruby/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- bindings/ruby/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: ruby/setup-ruby@v1
|
||||
with:
|
||||
ruby-version: '3.0'
|
||||
- uses: actions/checkout@v1
|
||||
- run: |
|
||||
cd bindings/ruby/ext
|
||||
ruby extconf.rb && make
|
449
.github/workflows/build.yml
vendored
449
.github/workflows/build.yml
vendored
@ -1,237 +1,308 @@
|
||||
name: CI
|
||||
on: [push]
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install sdl2
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install sdl2
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
|
||||
ubuntu-latest-gcc:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-gcc:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
ubuntu-latest-clang:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-clang:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
ubuntu-latest-gcc-sanitized:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-gcc-sanitized:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, THREAD, UNDEFINED]
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, THREAD, UNDEFINED]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
- name: Configure
|
||||
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
windows:
|
||||
runs-on: windows-latest
|
||||
windows:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch OpenBLAS
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
|
||||
7z x blas.zip -oblas -y
|
||||
copy blas/include/cblas.h .
|
||||
copy blas/include/openblas_config.h .
|
||||
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
|
||||
- name: Fetch OpenBLAS
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
|
||||
7z x blas.zip -oblas -y
|
||||
copy blas/include/cblas.h .
|
||||
copy blas/include/openblas_config.h .
|
||||
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy libopenblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy libopenblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
||||
ios:
|
||||
runs-on: macos-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Configure
|
||||
run: cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
|
||||
|
||||
- name: Build objc example
|
||||
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
|
||||
|
||||
- name: Build swiftui example
|
||||
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
|
||||
|
||||
android:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Install Java
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
distribution: zulu
|
||||
java-version: 17
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
48
.github/workflows/examples.yml
vendored
Normal file
48
.github/workflows/examples.yml
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
name: Examples Tests
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- examples/addon.node/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- examples/addon.node/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
addon_node-ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [ 16.x, 18.x ]
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v1
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: 'npm'
|
||||
|
||||
- name: Install package.json dependencies
|
||||
working-directory: ./examples/addon.node
|
||||
run: npm install
|
||||
|
||||
- name: Compile addon.node
|
||||
run: npx cmake-js compile -T whisper-addon -B Release
|
||||
|
||||
- name: Download test model
|
||||
run: |
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
- name: Test
|
||||
run: |
|
||||
cd examples/addon.node
|
||||
npm run test
|
14
.gitignore
vendored
14
.gitignore
vendored
@ -1,5 +1,8 @@
|
||||
*.o
|
||||
*.a
|
||||
.cache/
|
||||
.coreml/
|
||||
.test/
|
||||
.vs/
|
||||
.vscode/
|
||||
.DS_Store
|
||||
@ -8,6 +11,9 @@ build/
|
||||
build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
build-static/
|
||||
build-cublas/
|
||||
build-no-accel/
|
||||
build-sanitize-addr/
|
||||
build-sanitize-thread/
|
||||
|
||||
@ -15,9 +21,13 @@ build-sanitize-thread/
|
||||
/stream
|
||||
/command
|
||||
/talk
|
||||
/talk-llama
|
||||
/bench
|
||||
/quantize
|
||||
|
||||
arm_neon.h
|
||||
sync.sh
|
||||
libwhisper.a
|
||||
libwhisper.so
|
||||
compile_commands.json
|
||||
|
||||
@ -27,3 +37,7 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
|
||||
|
||||
extra/bench-gg.txt
|
||||
|
||||
models/*.mlmodel
|
||||
models/*.mlmodelc
|
||||
models/*.mlpackage
|
||||
|
186
CMakeLists.txt
186
CMakeLists.txt
@ -1,15 +1,20 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
|
||||
project(whisper.cpp VERSION 1.0.4)
|
||||
project(whisper.cpp VERSION 1.4.0)
|
||||
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
||||
add_compile_options(/utf-8)
|
||||
endif ()
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
include(cmake/GitVars.cmake)
|
||||
include(cmake/BuildTypes.cmake)
|
||||
include(GitVars)
|
||||
include(BuildTypes)
|
||||
|
||||
# configure project version
|
||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||
@ -34,29 +39,34 @@ endif()
|
||||
|
||||
# options
|
||||
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
|
||||
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
|
||||
|
||||
option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
|
||||
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
|
||||
|
||||
if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
|
||||
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
||||
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
|
||||
endif()
|
||||
|
||||
option(WHISPER_PERF "whisper: enable perf timings" OFF)
|
||||
option(WHISPER_PERF "whisper: enable perf timings" OFF)
|
||||
|
||||
# sanitizers
|
||||
|
||||
@ -82,25 +92,43 @@ endif()
|
||||
|
||||
# dependencies
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# on APPLE - include Accelerate framework
|
||||
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
# on APPLE
|
||||
if (APPLE)
|
||||
# include Accelerate framework
|
||||
if (NOT WHISPER_NO_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML)
|
||||
find_library(FOUNDATION_FRAMEWORK Foundation)
|
||||
find_library(COREML_FRAMEWORK CoreML)
|
||||
|
||||
if (COREML_FRAMEWORK)
|
||||
message(STATUS "CoreML framework found")
|
||||
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
|
||||
else()
|
||||
message(WARNING "CoreML framework not found")
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML_ALLOW_FALLBACK)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML_ALLOW_FALLBACK)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_SUPPORT_OPENBLAS)
|
||||
if (WHISPER_OPENBLAS)
|
||||
find_library(OPENBLAS_LIB
|
||||
NAMES openblas libopenblas
|
||||
)
|
||||
@ -114,6 +142,31 @@ if (WHISPER_SUPPORT_OPENBLAS)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_CUBLAS)
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "cuBLAS found")
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
|
||||
if (WHISPER_STATIC)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
else()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(WARNING "cuBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# compiler flags
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
|
||||
@ -131,6 +184,13 @@ if (WHISPER_ALL_WARNINGS)
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wpedantic \
|
||||
-Wcast-qual \
|
||||
")
|
||||
else()
|
||||
# todo : msvc
|
||||
@ -151,6 +211,7 @@ else()
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
@ -162,7 +223,12 @@ else()
|
||||
if(NOT WHISPER_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
if(NOT WHISPER_NO_F16C)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@ -171,6 +237,33 @@ if (WHISPER_PERF)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
||||
endif()
|
||||
|
||||
#
|
||||
# whisper.coreml - Core ML support
|
||||
#
|
||||
|
||||
if (WHISPER_COREML)
|
||||
set(TARGET whisper.coreml)
|
||||
|
||||
add_library(${TARGET}
|
||||
coreml/whisper-encoder.h
|
||||
coreml/whisper-encoder.mm
|
||||
coreml/whisper-encoder-impl.h
|
||||
coreml/whisper-encoder-impl.m
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
COMPILE_FLAGS "-fobjc-arc"
|
||||
)
|
||||
endif()
|
||||
|
||||
#
|
||||
# whisper - this is the main library of the project
|
||||
#
|
||||
@ -178,14 +271,23 @@ endif()
|
||||
set(TARGET whisper)
|
||||
|
||||
add_library(${TARGET}
|
||||
ggml.h
|
||||
ggml.c
|
||||
${GGML_CUDA_SOURCES}
|
||||
whisper.h
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
|
||||
if (WHISPER_COREML)
|
||||
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
@ -201,7 +303,19 @@ if (BUILD_SHARED_LIBS)
|
||||
|
||||
target_compile_definitions(${TARGET} PUBLIC
|
||||
WHISPER_SHARED
|
||||
GGML_SHARED
|
||||
)
|
||||
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
WHISPER_BUILD
|
||||
GGML_BUILD
|
||||
)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_SOURCES)
|
||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
@ -212,9 +326,13 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
PUBLIC_HEADER DESTINATION include
|
||||
)
|
||||
|
||||
#
|
||||
@ -227,7 +345,7 @@ add_subdirectory(bindings)
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
|
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Georgi Gerganov
|
||||
Copyright (c) 2023 Georgi Gerganov
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
174
Makefile
174
Makefile
@ -1,3 +1,5 @@
|
||||
default: main bench quantize
|
||||
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
@ -10,6 +12,9 @@ ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -27,10 +32,16 @@ endif
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
CFLAGS = -I. -O3 -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
|
||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
|
||||
ifneq ($(wildcard /usr/include/musl/*),)
|
||||
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
endif
|
||||
|
||||
# OS specific
|
||||
# TODO: support Windows
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
@ -53,10 +64,13 @@ endif
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mfma -mf16c
|
||||
CFLAGS += -mf16c
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
@ -65,10 +79,6 @@ ifeq ($(UNAME_M),x86_64)
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Linux)
|
||||
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
@ -80,12 +90,17 @@ ifeq ($(UNAME_M),x86_64)
|
||||
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
|
||||
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
@ -97,6 +112,11 @@ ifeq ($(UNAME_M),x86_64)
|
||||
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
endif
|
||||
else
|
||||
CFLAGS += -mfma -mf16c -mavx -mavx2
|
||||
@ -105,6 +125,18 @@ endif
|
||||
ifeq ($(UNAME_M),amd64)
|
||||
CFLAGS += -mavx -mavx2 -mfma -mf16c
|
||||
endif
|
||||
|
||||
ifneq ($(filter ppc64%,$(UNAME_M)),)
|
||||
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
||||
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
||||
CFLAGS += -mpower9-vector
|
||||
endif
|
||||
# Require c++23's std::byteswap for big-endian support.
|
||||
ifeq ($(UNAME_M),ppc64)
|
||||
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
|
||||
endif
|
||||
endif
|
||||
|
||||
ifndef WHISPER_NO_ACCELERATE
|
||||
# Mac M1 - include Accelerate framework
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -112,49 +144,106 @@ ifndef WHISPER_NO_ACCELERATE
|
||||
LDFLAGS += -framework Accelerate
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef WHISPER_COREML
|
||||
CXXFLAGS += -DWHISPER_USE_COREML
|
||||
LDFLAGS += -framework Foundation -framework CoreML
|
||||
|
||||
ifdef WHISPER_COREML_ALLOW_FALLBACK
|
||||
CXXFLAGS += -DWHISPER_COREML_ALLOW_FALLBACK
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef WHISPER_OPENBLAS
|
||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
|
||||
LDFLAGS += -lopenblas
|
||||
endif
|
||||
|
||||
ifdef WHISPER_CUBLAS
|
||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
||||
WHISPER_OBJ += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
endif
|
||||
|
||||
ifdef WHISPER_GPROF
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
endif
|
||||
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
CFLAGS += -mcpu=native
|
||||
CXXFLAGS += -mcpu=native
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||
# Raspberry Pi 1, 2, 3
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
|
||||
# 32-bit Raspberry Pi 1, 2, 3
|
||||
CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv7%,$(UNAME_M)),)
|
||||
# Raspberry Pi 4
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
||||
# 32-bit ARM, for example on Armbian or possibly raspbian
|
||||
CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
||||
|
||||
# 64-bit ARM, use these (TODO: auto-detect 64-bit)
|
||||
# CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
# Raspberry Pi 4
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
default: main
|
||||
#
|
||||
# Print build information
|
||||
#
|
||||
|
||||
$(info I whisper.cpp build info: )
|
||||
$(info I UNAME_S: $(UNAME_S))
|
||||
$(info I UNAME_P: $(UNAME_P))
|
||||
$(info I UNAME_M: $(UNAME_M))
|
||||
$(info I CFLAGS: $(CFLAGS))
|
||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
||||
$(info I LDFLAGS: $(LDFLAGS))
|
||||
$(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
#
|
||||
# Build library
|
||||
#
|
||||
|
||||
ggml.o: ggml.c ggml.h
|
||||
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
|
||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
whisper.o: whisper.cpp whisper.h
|
||||
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
libwhisper.a: ggml.o whisper.o
|
||||
$(AR) rcs libwhisper.a ggml.o whisper.o
|
||||
ifndef WHISPER_COREML
|
||||
WHISPER_OBJ += whisper.o
|
||||
else
|
||||
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
|
||||
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
|
||||
|
||||
libwhisper.so: ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
|
||||
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
|
||||
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
|
||||
|
||||
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
|
||||
endif
|
||||
|
||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||
|
||||
libwhisper.so: ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
|
||||
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -162,21 +251,30 @@ clean:
|
||||
|
||||
CC_SDL=`sdl2-config --cflags --libs`
|
||||
|
||||
main: examples/main/main.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
|
||||
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
|
||||
SRC_COMMON_SDL = examples/common-sdl.cpp
|
||||
|
||||
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
./main -h
|
||||
|
||||
stream: examples/stream/stream.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
|
||||
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
||||
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
bench: examples/bench/bench.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
#
|
||||
# Audio samples
|
||||
|
349
README.md
349
README.md
@ -1,20 +1,25 @@
|
||||
# whisper.cpp
|
||||
|
||||

|
||||
|
||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
[Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Beta: [v1.4.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.0) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Runs on the CPU
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
Supported platforms:
|
||||
@ -57,7 +62,9 @@ the Accelerate framework utilizes the special-purpose AMX coprocessor available
|
||||
|
||||
## Quick start
|
||||
|
||||
First, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
First clone the repository.
|
||||
|
||||
Then, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
|
||||
```bash
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -70,7 +77,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
|
||||
make
|
||||
|
||||
# transcribe an audio file
|
||||
./main -f input.wav
|
||||
./main -f samples/jfk.wav
|
||||
```
|
||||
|
||||
---
|
||||
@ -88,27 +95,38 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
Downloading ggml model base.en ...
|
||||
@ -127,7 +145,8 @@ Running base.en on all samples in ./samples ...
|
||||
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
|
||||
----------------------------------------------
|
||||
|
||||
whisper_model_load: loading model from 'models/ggml-base.en.bin'
|
||||
whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
|
||||
whisper_model_load: loading model
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 512
|
||||
@ -140,13 +159,14 @@ whisper_model_load: n_text_layer = 6
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 2
|
||||
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
|
||||
whisper_model_load: kv self size = 5.25 MB
|
||||
whisper_model_load: kv cross size = 17.58 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: mem_required = 506.00 MB
|
||||
whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model ctx = 140.60 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
|
||||
|
||||
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
@ -154,12 +174,13 @@ main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 proc
|
||||
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
|
||||
whisper_print_timings: load time = 105.91 ms
|
||||
whisper_print_timings: mel time = 24.62 ms
|
||||
whisper_print_timings: sample time = 3.63 ms
|
||||
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
|
||||
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
|
||||
whisper_print_timings: total time = 542.81 ms
|
||||
whisper_print_timings: fallbacks = 0 p / 0 h
|
||||
whisper_print_timings: load time = 113.81 ms
|
||||
whisper_print_timings: mel time = 15.40 ms
|
||||
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
|
||||
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
|
||||
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
|
||||
whisper_print_timings: total time = 476.31 ms
|
||||
```
|
||||
|
||||
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
|
||||
@ -202,26 +223,99 @@ make large
|
||||
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
|
||||
## Quantization
|
||||
|
||||
`whisper.cpp` supports integer quantization of the Whisper `ggml` models.
|
||||
Quantized models require less memory and disk space and depending on the hardware can be processed more efficiently.
|
||||
|
||||
Here are the steps for creating and using a quantized model:
|
||||
|
||||
```bash
|
||||
# quantize a model with Q5_0 method
|
||||
make quantize
|
||||
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
|
||||
|
||||
# run the examples as usual, specifying the quantized model file
|
||||
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
|
||||
```
|
||||
|
||||
## Core ML support
|
||||
|
||||
On Apple Silicon devices, the Encoder inference can be executed on the Apple Neural Engine (ANE) via Core ML. This can result in significant
|
||||
speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions for generating a Core ML model and using it with `whisper.cpp`:
|
||||
|
||||
- Install Python dependencies needed for the creation of the Core ML model:
|
||||
|
||||
```bash
|
||||
pip install ane_transformers
|
||||
pip install openai-whisper
|
||||
pip install coremltools
|
||||
```
|
||||
|
||||
- Generate a Core ML model. For example, to generate a `base.en` model, use:
|
||||
|
||||
```bash
|
||||
./models/generate-coreml-model.sh base.en
|
||||
```
|
||||
|
||||
This will generate the folder `models/ggml-base.en-encoder.mlmodelc`
|
||||
|
||||
- Build `whisper.cpp` with Core ML support:
|
||||
|
||||
```bash
|
||||
# using Makefile
|
||||
make clean
|
||||
WHISPER_COREML=1 make -j
|
||||
|
||||
# using CMake
|
||||
cd build
|
||||
cmake -DWHISPER_COREML=1 ..
|
||||
```
|
||||
|
||||
- Run the examples as usual. For example:
|
||||
|
||||
```bash
|
||||
./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
|
||||
...
|
||||
|
||||
whisper_init_state: loading Core ML model from 'models/ggml-base.en-encoder.mlmodelc'
|
||||
whisper_init_state: first run on a device may take a while ...
|
||||
whisper_init_state: Core ML model loaded
|
||||
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
|
||||
Next runs are faster.
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
|
||||
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
|
||||
With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extend through cuBLAS.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
|
||||
```
|
||||
make clean
|
||||
WHISPER_CUBLAS=1 make -j
|
||||
```
|
||||
|
||||
Run all the examples as usual.
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
- No GPU support
|
||||
- Very basic greedy sampling scheme - always pick up the token with highest probability.
|
||||
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
|
||||
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
|
||||
to run the python code with the following parameters:
|
||||
|
||||
```
|
||||
whisper --best_of None --beam_size None ...
|
||||
```
|
||||
|
||||
In the future, `whisper.cpp` will support more sampling strategies.
|
||||
|
||||
## Another example
|
||||
|
||||
@ -234,7 +328,8 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
|
||||
```java
|
||||
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
|
||||
|
||||
whisper_model_load: loading model from 'models/ggml-medium.en.bin'
|
||||
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
|
||||
whisper_model_load: loading model
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 1024
|
||||
@ -247,65 +342,71 @@ whisper_model_load: n_text_layer = 24
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 4
|
||||
whisper_model_load: mem_required = 2610.00 MB
|
||||
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
|
||||
whisper_model_load: kv self size = 42.00 MB
|
||||
whisper_model_load: kv cross size = 140.62 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: ggml ctx size = 1644.97 MB
|
||||
whisper_model_load: memory size = 182.62 MB
|
||||
whisper_model_load: model size = 1462.12 MB
|
||||
whisper_model_load: model ctx = 1462.35 MB
|
||||
whisper_model_load: model size = 1462.12 MB
|
||||
|
||||
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
|
||||
|
||||
[00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
|
||||
[00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:29.000 --> 00:32.000] On board was a crew of seven.
|
||||
[00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
|
||||
[00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
|
||||
[00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
|
||||
[00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
|
||||
[00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
|
||||
[01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
|
||||
[01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
|
||||
[01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
|
||||
[01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
|
||||
[01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
|
||||
[01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
|
||||
[01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
|
||||
[01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
|
||||
[01:52.000 --> 01:56.000] The cause in which they died will continue.
|
||||
[01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
|
||||
[02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
|
||||
[02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
|
||||
[02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
|
||||
[02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
|
||||
[02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
|
||||
[02:35.000 --> 02:39.000] and calls them each by name."
|
||||
[02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
|
||||
[02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
|
||||
[03:01.000 --> 03:05.000] yet we can pray that all are safely home.
|
||||
[03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
|
||||
[03:13.000 --> 03:41.000] Audio
|
||||
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
|
||||
whisper_print_timings: load time = 575.92 ms
|
||||
whisper_print_timings: mel time = 230.60 ms
|
||||
whisper_print_timings: sample time = 73.19 ms
|
||||
whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
|
||||
whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
|
||||
whisper_print_timings: total time = 33686.27 ms
|
||||
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
|
||||
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
|
||||
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
|
||||
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
|
||||
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
|
||||
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
|
||||
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
|
||||
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
|
||||
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
|
||||
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
|
||||
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
|
||||
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
|
||||
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
|
||||
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
|
||||
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
|
||||
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
|
||||
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
|
||||
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
|
||||
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
|
||||
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
|
||||
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
|
||||
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
|
||||
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
|
||||
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
|
||||
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
|
||||
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
|
||||
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
|
||||
[00:03:13.000 --> 00:03:19.000] [Silence]
|
||||
|
||||
|
||||
whisper_print_timings: fallbacks = 1 p / 0 h
|
||||
whisper_print_timings: load time = 569.03 ms
|
||||
whisper_print_timings: mel time = 146.85 ms
|
||||
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
|
||||
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
|
||||
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
|
||||
whisper_print_timings: total time = 32733.52 ms
|
||||
```
|
||||
</details>
|
||||
|
||||
## Real-time audio input example
|
||||
|
||||
This is a naive example of performing real-time inference on audio from your microphone.
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
make stream
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
@ -316,18 +417,22 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
|
||||
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
|
||||
to highlight words with high or low confidence:
|
||||
|
||||
```java
|
||||
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
|
||||
```
|
||||
|
||||
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
|
||||
|
||||
## Controlling the length of the generated text segments (experimental)
|
||||
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
@ -351,11 +456,11 @@ The `--max-len` argument can be used to obtain word-level timestamps. Simply use
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
[00:00:00.000 --> 00:00:00.320]
|
||||
[00:00:00.000 --> 00:00:00.320]
|
||||
[00:00:00.320 --> 00:00:00.370] And
|
||||
[00:00:00.370 --> 00:00:00.690] so
|
||||
[00:00:00.690 --> 00:00:00.850] my
|
||||
@ -421,6 +526,19 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
|
||||
|
||||
---
|
||||
|
||||
## Video comparison of different models
|
||||
|
||||
Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
|
||||
|
||||
```java
|
||||
./extra/bench-wts.sh samples/jfk.wav
|
||||
ffplay ./samples/jfk.wav.all.mp4
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8c3b-eb9f9c39d6fc.mp4
|
||||
|
||||
---
|
||||
|
||||
## Benchmarks
|
||||
|
||||
In order to have an objective comparison of the performance of the inference across different system configurations,
|
||||
@ -441,18 +559,29 @@ The original models are converted to a custom binary format. This allows to pack
|
||||
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
|
||||
or manually from here:
|
||||
|
||||
- https://huggingface.co/datasets/ggerganov/whisper.cpp
|
||||
- https://huggingface.co/ggerganov/whisper.cpp
|
||||
- https://ggml.ggerganov.com
|
||||
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
|
||||
in [models](models).
|
||||
|
||||
## Bindings
|
||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript)
|
||||
- [ ] Python: soon
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
|
||||
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
||||
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
||||
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
||||
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
||||
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
||||
- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
||||
- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
||||
|
||||
## Examples
|
||||
|
||||
@ -466,7 +595,9 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
|
2
bindings/go/.gitignore
vendored
Normal file
2
bindings/go/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
build
|
||||
models
|
21
bindings/go/LICENSE
Normal file
21
bindings/go/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 David Thorpe
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
38
bindings/go/Makefile
Normal file
38
bindings/go/Makefile
Normal file
@ -0,0 +1,38 @@
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
@install -d ${BUILD_DIR}
|
||||
@echo Mkdir ${MODELS_DIR}
|
||||
@install -d ${MODELS_DIR}
|
||||
|
||||
modtidy:
|
||||
@go mod tidy
|
||||
|
||||
clean:
|
||||
@echo Clean
|
||||
@rm -fr $(BUILD_DIR)
|
||||
@go clean
|
100
bindings/go/README.md
Normal file
100
bindings/go/README.md
Normal file
@ -0,0 +1,100 @@
|
||||
# Go bindings for Whisper
|
||||
|
||||
This package provides Go bindings for whisper.cpp. They have been tested on:
|
||||
|
||||
* Darwin (OS X) 12.6 on x64_64
|
||||
* Debian Linux on arm64
|
||||
* Fedora Linux on x86_64
|
||||
|
||||
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||
is as follows:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var modelpath string // Path to the model
|
||||
var samples []float32 // Samples to process
|
||||
|
||||
// Load the model
|
||||
model, err := whisper.New(modelpath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := context.Process(samples, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Building & Testing
|
||||
|
||||
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp/bindings/go
|
||||
make test
|
||||
```
|
||||
|
||||
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
||||
|
||||
```bash
|
||||
make examples
|
||||
```
|
||||
|
||||
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-model-download -out models
|
||||
```
|
||||
|
||||
And you can then test a model against samples with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
## Using the bindings
|
||||
|
||||
To use the bindings in your own software,
|
||||
|
||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
||||
|
||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
||||
|
||||
The API Documentation:
|
||||
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
||||
|
||||
Getting help:
|
||||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
5
bindings/go/doc.go
Normal file
5
bindings/go/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
/*
|
||||
github.com/ggerganov/whisper.cpp/bindings/go
|
||||
provides a speech-to-text service bindings for the Go programming language.
|
||||
*/
|
||||
package whisper
|
30
bindings/go/examples/go-model-download/context.go
Normal file
30
bindings/go/examples/go-model-download/context.go
Normal file
@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
// ContextForSignal returns a context object which is cancelled when a signal
|
||||
// is received. It returns nil if no signal parameter is provided
|
||||
func ContextForSignal(signals ...os.Signal) context.Context {
|
||||
if len(signals) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := make(chan os.Signal)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Send message on channel when signal received
|
||||
signal.Notify(ch, signals...)
|
||||
|
||||
// When any signal received, call cancel
|
||||
go func() {
|
||||
<-ch
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Return success
|
||||
return ctx
|
||||
}
|
208
bindings/go/examples/go-model-download/main.go
Normal file
208
bindings/go/examples/go-model-download/main.go
Normal file
@ -0,0 +1,208 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
// The output folder. When not set, use current working directory.
|
||||
flagOut = flag.String("out", "", "Output folder")
|
||||
|
||||
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
||||
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
||||
|
||||
// Quiet parameter - will not print progress if set
|
||||
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MAIN
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
name := filepath.Base(flag.CommandLine.Name())
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Get output path
|
||||
out, err := GetOut()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Create context which quits on SIGINT or SIGQUIT
|
||||
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
||||
|
||||
// Progress filehandle
|
||||
progress := os.Stdout
|
||||
if *flagQuiet {
|
||||
progress, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
defer progress.Close()
|
||||
}
|
||||
|
||||
// Download models - exit on error or interrupt
|
||||
for _, model := range GetModels() {
|
||||
url, err := URLForModel(model)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
continue
|
||||
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
||||
continue
|
||||
} else if err == context.Canceled {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "\nInterrupted")
|
||||
break
|
||||
} else if err == context.DeadlineExceeded {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "Timeout downloading model")
|
||||
continue
|
||||
} else {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// GetOut returns the path to the output directory
|
||||
func GetOut() (string, error) {
|
||||
if *flagOut == "" {
|
||||
return os.Getwd()
|
||||
}
|
||||
if info, err := os.Stat(*flagOut); err != nil {
|
||||
return "", err
|
||||
} else if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory: %s", info.Name())
|
||||
} else {
|
||||
return *flagOut, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns the list of models to download
|
||||
func GetModels() []string {
|
||||
if flag.NArg() == 0 {
|
||||
return modelNames
|
||||
} else {
|
||||
return flag.Args()
|
||||
}
|
||||
}
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
if filepath.Ext(model) != srcExt {
|
||||
model += srcExt
|
||||
}
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = filepath.Join(url.Path, model)
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Download downloads the model from the given URL to the given output directory
|
||||
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
||||
// Create HTTP client
|
||||
client := http.Client{
|
||||
Timeout: *flagTimeout,
|
||||
}
|
||||
|
||||
// Initiate the download
|
||||
req, err := http.NewRequest("GET", model, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
||||
}
|
||||
|
||||
// If output file exists and is the same size as the model, skip
|
||||
path := filepath.Join(out, filepath.Base(model))
|
||||
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
||||
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Create file
|
||||
w, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Report
|
||||
fmt.Fprintln(p, "Downloading", model, "to", out)
|
||||
|
||||
// Progressively download the model
|
||||
data := make([]byte, bufSize)
|
||||
count, pct := int64(0), int64(0)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Cancelled, return error
|
||||
return path, ctx.Err()
|
||||
case <-ticker.C:
|
||||
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
||||
default:
|
||||
// Read body
|
||||
n, err := resp.Body.Read(data)
|
||||
if err != nil {
|
||||
DownloadReport(p, pct, count, resp.ContentLength)
|
||||
return path, err
|
||||
} else if m, err := w.Write(data[:n]); err != nil {
|
||||
return path, err
|
||||
} else {
|
||||
count += int64(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Report periodically reports the download progress when percentage changes
|
||||
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
|
||||
pct_ := count * 100 / total
|
||||
if pct_ > pct {
|
||||
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
|
||||
}
|
||||
return pct_
|
||||
}
|
22
bindings/go/examples/go-whisper/color.go
Normal file
22
bindings/go/examples/go-whisper/color.go
Normal file
@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
|
||||
RGBSuffix = "m"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Colorize text with RGB values, from 0 to 23
|
||||
func Colorize(text string, v int) string {
|
||||
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
|
||||
// Grayscale colors are in the range 232-255
|
||||
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
|
||||
}
|
156
bindings/go/examples/go-whisper/flags.go
Normal file
156
bindings/go/examples/go-whisper/flags.go
Normal file
@ -0,0 +1,156 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type Flags struct {
|
||||
*flag.FlagSet
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewFlags(name string, args []string) (*Flags, error) {
|
||||
flags := &Flags{
|
||||
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||
}
|
||||
|
||||
// Register the command line arguments
|
||||
registerFlags(flags)
|
||||
|
||||
// Parse command line
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (flags *Flags) GetModel() string {
|
||||
return flags.Lookup("model").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) GetLanguage() string {
|
||||
return flags.Lookup("language").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTranslate() bool {
|
||||
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOffset() time.Duration {
|
||||
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetDuration() time.Duration {
|
||||
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetThreads() uint {
|
||||
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOut() string {
|
||||
return strings.ToLower(flags.Lookup("out").Value.String())
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsColorize() bool {
|
||||
return flags.Lookup("colorize").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxLen() uint {
|
||||
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxTokens() uint {
|
||||
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetWordThreshold() float32 {
|
||||
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
|
||||
}
|
||||
|
||||
func (flags *Flags) SetParams(context whisper.Context) error {
|
||||
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
|
||||
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if flags.IsTranslate() && context.IsMultilingual() {
|
||||
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
|
||||
context.SetTranslate(true)
|
||||
}
|
||||
if offset := flags.GetOffset(); offset != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
|
||||
context.SetOffset(offset)
|
||||
}
|
||||
if duration := flags.GetDuration(); duration != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
||||
context.SetDuration(duration)
|
||||
}
|
||||
if flags.IsSpeedup() {
|
||||
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if threads := flags.GetThreads(); threads != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
||||
context.SetThreads(threads)
|
||||
}
|
||||
if max_len := flags.GetMaxLen(); max_len != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
|
||||
context.SetMaxSegmentLength(max_len)
|
||||
}
|
||||
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
|
||||
context.SetMaxTokensPerSegment(max_tokens)
|
||||
}
|
||||
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
|
||||
context.SetTokenThreshold(word_threshold)
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func registerFlags(flag *Flags) {
|
||||
flag.String("model", "", "Path to the model file")
|
||||
flag.String("language", "", "Spoken language")
|
||||
flag.Bool("translate", false, "Translate from source language to english")
|
||||
flag.Duration("offset", 0, "Time offset")
|
||||
flag.Duration("duration", 0, "Duration of audio to process")
|
||||
flag.Uint("threads", 0, "Number of threads to use")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
||||
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
||||
flag.Float64("word-thold", 0, "Maximum segment score")
|
||||
flag.Bool("tokens", false, "Display tokens")
|
||||
flag.Bool("colorize", false, "Colorize tokens")
|
||||
flag.String("out", "", "Output format (srt, none or leave as empty string)")
|
||||
}
|
43
bindings/go/examples/go-whisper/main.go
Normal file
43
bindings/go/examples/go-whisper/main.go
Normal file
@ -0,0 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
||||
if err == flag.ErrHelp {
|
||||
os.Exit(0)
|
||||
} else if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
} else if flags.GetModel() == "" {
|
||||
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
||||
os.Exit(1)
|
||||
} else if flags.NArg() == 0 {
|
||||
fmt.Fprintln(os.Stderr, "No input files specified")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(flags.GetModel())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process files
|
||||
for _, filename := range flags.Args() {
|
||||
if err := Process(model, filename, flags); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
132
bindings/go/examples/go-whisper/process.go
Normal file
132
bindings/go/examples/go-whisper/process.go
Normal file
@ -0,0 +1,132 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// Package imports
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
var data []float32
|
||||
|
||||
// Create processing context
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
if err := flags.SetParams(context); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s\n", context.SystemInfo())
|
||||
|
||||
// Open the file
|
||||
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||
return err
|
||||
} else if dec.SampleRate != whisper.SampleRate {
|
||||
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||
} else if dec.NumChans != 1 {
|
||||
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||
} else {
|
||||
data = buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// Segment callback when -tokens is specified
|
||||
var cb whisper.SegmentCallback
|
||||
if flags.IsTokens() {
|
||||
cb = func(segment whisper.Segment) {
|
||||
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
for _, token := range segment.Tokens {
|
||||
if flags.IsColorize() && context.IsText(token) {
|
||||
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
|
||||
} else {
|
||||
fmt.Fprint(flags.Output(), token.Text, " ")
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||
context.ResetTimings()
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
context.PrintTimings()
|
||||
|
||||
// Print out the results
|
||||
switch {
|
||||
case flags.GetOut() == "srt":
|
||||
return OutputSRT(os.Stdout, context)
|
||||
case flags.GetOut() == "none":
|
||||
return nil
|
||||
default:
|
||||
return Output(os.Stdout, context, flags.IsColorize())
|
||||
}
|
||||
}
|
||||
|
||||
// Output text as SRT file
|
||||
func OutputSRT(w io.Writer, context whisper.Context) error {
|
||||
n := 1
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(w, n)
|
||||
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
|
||||
fmt.Fprintln(w, segment.Text)
|
||||
fmt.Fprintln(w, "")
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Output text to terminal
|
||||
func Output(w io.Writer, context whisper.Context, colorize bool) error {
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
if colorize {
|
||||
for _, token := range segment.Tokens {
|
||||
if !context.IsText(token) {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
|
||||
}
|
||||
fmt.Fprint(w, "\n")
|
||||
} else {
|
||||
fmt.Fprintln(w, " ", segment.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return srtTimestamp
|
||||
func srtTimestamp(t time.Duration) string {
|
||||
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
|
||||
}
|
16
bindings/go/go.mod
Normal file
16
bindings/go/go.mod
Normal file
@ -0,0 +1,16 @@
|
||||
module github.com/ggerganov/whisper.cpp/bindings/go
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
23
bindings/go/go.sum
Normal file
23
bindings/go/go.sum
Normal file
@ -0,0 +1,23 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
169
bindings/go/params.go
Normal file
169
bindings/go/params.go
Normal file
@ -0,0 +1,169 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#include <whisper.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (p *Params) SetTranslate(v bool) {
|
||||
p.translate = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetNoContext(v bool) {
|
||||
p.no_context = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSingleSegment(v bool) {
|
||||
p.single_segment = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintSpecial(v bool) {
|
||||
p.print_special = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintProgress(v bool) {
|
||||
p.print_progress = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintRealtime(v bool) {
|
||||
p.print_realtime = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintTimestamps(v bool) {
|
||||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSpeedup(v bool) {
|
||||
p.speed_up = toBool(v)
|
||||
}
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
if lang == -1 {
|
||||
p.language = nil
|
||||
return nil
|
||||
}
|
||||
str := C.whisper_lang_str(C.int(lang))
|
||||
if str == nil {
|
||||
return ErrInvalidLanguage
|
||||
} else {
|
||||
p.language = str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get language id
|
||||
func (p *Params) Language() int {
|
||||
if p.language == nil {
|
||||
return -1
|
||||
}
|
||||
return int(C.whisper_lang_id(p.language))
|
||||
}
|
||||
|
||||
// Threads available
|
||||
func (p *Params) Threads() int {
|
||||
return int(p.n_threads)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (p *Params) SetThreads(threads int) {
|
||||
p.n_threads = C.int(threads)
|
||||
}
|
||||
|
||||
// Set start offset in ms
|
||||
func (p *Params) SetOffset(offset_ms int) {
|
||||
p.offset_ms = C.int(offset_ms)
|
||||
}
|
||||
|
||||
// Set audio duration to process in ms
|
||||
func (p *Params) SetDuration(duration_ms int) {
|
||||
p.duration_ms = C.int(duration_ms)
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (p *Params) SetTokenThreshold(t float32) {
|
||||
p.thold_pt = C.float(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (p *Params) SetTokenSumThreshold(t float32) {
|
||||
p.thold_ptsum = C.float(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (p *Params) SetMaxSegmentLength(n int) {
|
||||
p.max_len = C.int(n)
|
||||
}
|
||||
|
||||
func (p *Params) SetTokenTimestamps(b bool) {
|
||||
p.token_timestamps = toBool(b)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (p *Params) SetMaxTokensPerSegment(n int) {
|
||||
p.max_tokens = C.int(n)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toBool(v bool) C.bool {
|
||||
if v {
|
||||
return C.bool(true)
|
||||
}
|
||||
return C.bool(false)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (p *Params) String() string {
|
||||
str := "<whisper.params"
|
||||
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
||||
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
||||
if p.language != nil {
|
||||
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
||||
}
|
||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
if p.no_context {
|
||||
str += " no_context"
|
||||
}
|
||||
if p.single_segment {
|
||||
str += " single_segment"
|
||||
}
|
||||
if p.print_special {
|
||||
str += " print_special"
|
||||
}
|
||||
if p.print_progress {
|
||||
str += " print_progress"
|
||||
}
|
||||
if p.print_realtime {
|
||||
str += " print_realtime"
|
||||
}
|
||||
if p.print_timestamps {
|
||||
str += " print_timestamps"
|
||||
}
|
||||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.speed_up {
|
||||
str += " speed_up"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
28
bindings/go/pkg/whisper/consts.go
Normal file
28
bindings/go/pkg/whisper/consts.go
Normal file
@ -0,0 +1,28 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrModelNotMultilingual = errors.New("model is not multilingual")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
// SampleRate is the sample rate of the audio data.
|
||||
const SampleRate = whisper.SampleRate
|
||||
|
||||
// SampleBits is the number of bytes per sample.
|
||||
const SampleBits = whisper.SampleBits
|
299
bindings/go/pkg/whisper/context.go
Normal file
299
bindings/go/pkg/whisper/context.go
Normal file
@ -0,0 +1,299 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type context struct {
|
||||
n int
|
||||
model *model
|
||||
params whisper.Params
|
||||
}
|
||||
|
||||
// Make sure context adheres to the interface
|
||||
var _ Context = (*context)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func newContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
|
||||
// Return success
|
||||
return context, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Set the language to use for speech recognition.
|
||||
func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
|
||||
if lang == "auto" {
|
||||
context.params.SetLanguage(-1)
|
||||
} else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *context) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
id := context.params.Language()
|
||||
if id == -1 {
|
||||
return "auto"
|
||||
}
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
// Set translate flag
|
||||
func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (context *context) SetThreads(v uint) {
|
||||
context.params.SetThreads(int(v))
|
||||
}
|
||||
|
||||
// Set time offset
|
||||
func (context *context) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set duration of audio to process
|
||||
func (context *context) SetDuration(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (context *context) SetTokenThreshold(t float32) {
|
||||
context.params.SetTokenThreshold(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (context *context) SetTokenSumThreshold(t float32) {
|
||||
context.params.SetTokenSumThreshold(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set token timestamps flag
|
||||
func (context *context) SetTokenTimestamps(b bool) {
|
||||
context.params.SetTokenTimestamps(b)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// ResetTimings resets the mode timings. Should be called before processing
|
||||
func (context *context) ResetTimings() {
|
||||
context.model.ctx.Whisper_reset_timings()
|
||||
}
|
||||
|
||||
// PrintTimings prints the model timings to stdout.
|
||||
func (context *context) PrintTimings() {
|
||||
context.model.ctx.Whisper_print_timings()
|
||||
}
|
||||
|
||||
// SystemInfo returns the system information
|
||||
func (context *context) SystemInfo() string {
|
||||
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
|
||||
context.params.Threads(),
|
||||
runtime.NumCPU(),
|
||||
whisper.Whisper_print_system_info(),
|
||||
)
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
|
||||
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return langProbs, nil
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if cb != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
// We don't do parallel processing at the moment
|
||||
processors := 0
|
||||
if processors > 1 {
|
||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the next segment of tokens
|
||||
func (context *context) NextSegment() (Segment, error) {
|
||||
if context.model.ctx == nil {
|
||||
return Segment{}, ErrInternalAppError
|
||||
}
|
||||
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
// Populate result
|
||||
result := toSegment(context.model.ctx, context.n)
|
||||
|
||||
// Increment the cursor
|
||||
context.n++
|
||||
|
||||
// Return success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Test for text tokens
|
||||
func (context *context) IsText(t Token) bool {
|
||||
switch {
|
||||
case context.IsBEG(t):
|
||||
return false
|
||||
case context.IsSOT(t):
|
||||
return false
|
||||
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
||||
return false
|
||||
case context.IsPREV(t):
|
||||
return false
|
||||
case context.IsSOLM(t):
|
||||
return false
|
||||
case context.IsNOT(t):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Test for "begin" token
|
||||
func (context *context) IsBEG(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
||||
}
|
||||
|
||||
// Test for "start of transcription" token
|
||||
func (context *context) IsSOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
||||
}
|
||||
|
||||
// Test for "end of transcription" token
|
||||
func (context *context) IsEOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
||||
}
|
||||
|
||||
// Test for "start of prev" token
|
||||
func (context *context) IsPREV(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
||||
}
|
||||
|
||||
// Test for "start of lm" token
|
||||
func (context *context) IsSOLM(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
||||
}
|
||||
|
||||
// Test for "No timestamps" token
|
||||
func (context *context) IsNOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
||||
}
|
||||
|
||||
// Test for token associated with a specific language
|
||||
func (context *context) IsLANG(t Token, lang string) bool {
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokens(ctx, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
data := ctx.Whisper_full_get_token_data(n, i)
|
||||
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: ctx.Whisper_full_get_token_text(n, i),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
Start: time.Duration(data.T0()) * time.Millisecond * 10,
|
||||
End: time.Duration(data.T1()) * time.Millisecond * 10,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
55
bindings/go/pkg/whisper/context_test.go
Normal file
55
bindings/go/pkg/whisper/context_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "../../models/ggml-tiny.bin"
|
||||
SamplePath = "../../samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
assert.NoError(model.Close())
|
||||
|
||||
t.Log("languages=", model.Languages())
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
// Get context for decoding
|
||||
ctx, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
assert.NotNil(ctx)
|
||||
|
||||
}
|
4
bindings/go/pkg/whisper/doc.go
Normal file
4
bindings/go/pkg/whisper/doc.go
Normal file
@ -0,0 +1,4 @@
|
||||
/*
|
||||
This is the higher-level speech-to-text whisper.cpp API for go
|
||||
*/
|
||||
package whisper
|
93
bindings/go/pkg/whisper/interface.go
Normal file
93
bindings/go/pkg/whisper/interface.go
Normal file
@ -0,0 +1,93 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
// SegmentCallback is the callback function for processing segments in real
|
||||
// time. It is called during the Process function
|
||||
type SegmentCallback func(Segment)
|
||||
|
||||
// Model is the interface to a whisper model. Create a new model with the
|
||||
// function whisper.New(string)
|
||||
type Model interface {
|
||||
io.Closer
|
||||
|
||||
// Return a new speech-to-text context.
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return true if the model is multilingual.
|
||||
IsMultilingual() bool
|
||||
|
||||
// Return all languages supported.
|
||||
Languages() []string
|
||||
}
|
||||
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
// callback function during processing.
|
||||
Process([]float32, SegmentCallback) error
|
||||
|
||||
// After process is called, return segments until the end of the stream
|
||||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
|
||||
IsBEG(Token) bool // Test for "begin" token
|
||||
IsSOT(Token) bool // Test for "start of transcription" token
|
||||
IsEOT(Token) bool // Test for "end of transcription" token
|
||||
IsPREV(Token) bool // Test for "start of prev" token
|
||||
IsSOLM(Token) bool // Test for "start of lm" token
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
|
||||
// Timings
|
||||
PrintTimings()
|
||||
ResetTimings()
|
||||
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
type Segment struct {
|
||||
// Segment Number
|
||||
Num int
|
||||
|
||||
// Time beginning and end timestamps for the segment.
|
||||
Start, End time.Duration
|
||||
|
||||
// The text of the segment.
|
||||
Text string
|
||||
|
||||
// The tokens of the segment.
|
||||
Tokens []Token
|
||||
}
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
Start, End time.Duration
|
||||
}
|
101
bindings/go/pkg/whisper/model.go
Normal file
101
bindings/go/pkg/whisper/model.go
Normal file
@ -0,0 +1,101 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type model struct {
|
||||
path string
|
||||
ctx *whisper.Context
|
||||
}
|
||||
|
||||
// Make sure model adheres to the interface
|
||||
var _ Model = (*model)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
if model.ctx != nil {
|
||||
model.ctx.Whisper_free()
|
||||
}
|
||||
|
||||
// Release resources
|
||||
model.ctx = nil
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (model *model) String() string {
|
||||
str := "<whisper.model"
|
||||
if model.ctx != nil {
|
||||
str += fmt.Sprintf(" model=%q", model.path)
|
||||
}
|
||||
return str + ">"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return true if model is multilingual (language and translation options are supported)
|
||||
func (model *model) IsMultilingual() bool {
|
||||
return model.ctx.Whisper_is_multilingual() != 0
|
||||
}
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (model *model) NewContext() (Context, error) {
|
||||
if model.ctx == nil {
|
||||
return nil, ErrInternalAppError
|
||||
}
|
||||
|
||||
// Create new context
|
||||
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
params.SetNoContext(true)
|
||||
|
||||
// Return new context
|
||||
return newContext(model, params)
|
||||
}
|
BIN
bindings/go/samples/jfk.wav
Normal file
BIN
bindings/go/samples/jfk.wav
Normal file
Binary file not shown.
417
bindings/go/whisper.go
Normal file
417
bindings/go/whisper.go
Normal file
@ -0,0 +1,417 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
extern bool callEncoderBegin(void* user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callNewSegment(user_data, n_new);
|
||||
}
|
||||
}
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
return callEncoderBegin(user_data);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get default parameters and set callbacks
|
||||
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||
params.new_segment_callback = whisper_new_segment_cb;
|
||||
params.new_segment_callback_user_data = (void*)(ctx);
|
||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type (
|
||||
Context C.struct_whisper_context
|
||||
Token C.whisper_token
|
||||
TokenData C.struct_whisper_token_data
|
||||
SamplingStrategy C.enum_whisper_sampling_strategy
|
||||
Params C.struct_whisper_full_params
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GLOBALS
|
||||
|
||||
const (
|
||||
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
||||
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
||||
)
|
||||
|
||||
const (
|
||||
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
||||
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
||||
NumFFT = C.WHISPER_N_FFT
|
||||
NumMEL = C.WHISPER_N_MEL
|
||||
HopLength = C.WHISPER_HOP_LENGTH
|
||||
ChunkSize = C.WHISPER_CHUNK_SIZE
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
||||
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
||||
ErrConversionFailed = errors.New("whisper_convert failed")
|
||||
ErrInvalidLanguage = errors.New("invalid language")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Allocates all memory needed for the model and loads the model from the given file.
|
||||
// Returns NULL on failure.
|
||||
func Whisper_init(path string) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the provided whisper context.
|
||||
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
||||
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
||||
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
// Make sure to call whisper_encode() first.
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
// n_past is the number of tokens to use from previous decoder calls.
|
||||
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
||||
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success
|
||||
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
cText := C.CString(text)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
||||
return int(n), nil
|
||||
} else {
|
||||
return 0, ErrTokenizerFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
//
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
func Whisper_lang_max_id() int {
|
||||
return int(C.whisper_lang_max_id())
|
||||
}
|
||||
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
||||
// returns empty string if not found
|
||||
func Whisper_lang_str(id int) string {
|
||||
return C.GoString(C.whisper_lang_str(C.int(id)))
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
||||
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||
return nil, ErrAutoDetectFailed
|
||||
} else {
|
||||
return probs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len() int {
|
||||
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_vocab() int {
|
||||
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// The probabilities for the next token
|
||||
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
||||
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
||||
//}
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
||||
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_eot() Token {
|
||||
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_sot() Token {
|
||||
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_prev() Token {
|
||||
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_solm() Token {
|
||||
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_not() Token {
|
||||
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_beg() Token {
|
||||
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
||||
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_translate() Token {
|
||||
return Token(C.whisper_token_translate())
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_transcribe() Token {
|
||||
return Token(C.whisper_token_transcribe())
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_print_timings() {
|
||||
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_reset_timings() {
|
||||
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Print system information
|
||||
func Whisper_print_system_info() string {
|
||||
return C.GoString(C.whisper_print_system_info())
|
||||
}
|
||||
|
||||
// Return default parameters for a strategy
|
||||
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
||||
// Get default parameters
|
||||
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
||||
}
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
|
||||
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Number of generated text segments.
|
||||
// A segment can be a few words, a sentence, or even a paragraph.
|
||||
func (ctx *Context) Whisper_full_n_segments() int {
|
||||
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the text of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get number of tokens in the specified segment.
|
||||
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the token text of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the token of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CALLBACKS
|
||||
|
||||
var (
|
||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||
)
|
||||
|
||||
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
if fn == nil {
|
||||
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
//export callNewSegment
|
||||
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
if fn, ok := cbNewSegment[user_data]; ok {
|
||||
fn(int(new))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
if fn() {
|
||||
return C.bool(true)
|
||||
} else {
|
||||
return C.bool(false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (t TokenData) T0() int64 {
|
||||
return int64(t.t0)
|
||||
}
|
||||
|
||||
func (t TokenData) T1() int64 {
|
||||
return int64(t.t1)
|
||||
}
|
113
bindings/go/whisper_test.go
Normal file
113
bindings/go/whisper_test.go
Normal file
@ -0,0 +1,113 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
wav "github.com/go-audio/wav"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "models/ggml-small.en.bin"
|
||||
SamplePath = "samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
ctx.Whisper_free()
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Run whisper
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
err = ctx.Whisper_full(params, data, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
assert.GreaterOrEqual(num_segments, 1)
|
||||
for i := 0; i < num_segments; i++ {
|
||||
str := ctx.Whisper_full_get_segment_text(i)
|
||||
assert.NotEmpty(str)
|
||||
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
||||
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
||||
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_002(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
assert.NotEmpty(str)
|
||||
t.Log(str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_003(t *testing.T) {
|
||||
threads := runtime.NumCPU()
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Make the model
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
// Get MEL
|
||||
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
||||
|
||||
// Get Languages
|
||||
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
||||
assert.NoError(err)
|
||||
for i, p := range languages {
|
||||
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||
}
|
||||
}
|
Submodule bindings/ios updated: 1502317fe0...30edc4c500
@ -20,7 +20,7 @@ struct whisper_context * g_context;
|
||||
EMSCRIPTEN_BINDINGS(whisper) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
if (g_context == nullptr) {
|
||||
g_context = whisper_init(path_model.c_str());
|
||||
g_context = whisper_init_from_file(path_model.c_str());
|
||||
if (g_context != nullptr) {
|
||||
return true;
|
||||
} else {
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.0.4",
|
||||
"version": "1.4.0",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
File diff suppressed because one or more lines are too long
7
bindings/ruby/ext/.gitignore
vendored
Normal file
7
bindings/ruby/ext/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
Makefile
|
||||
ggml.c
|
||||
ggml.h
|
||||
whisper.bundle
|
||||
whisper.cpp
|
||||
whisper.h
|
||||
dr_wav.h
|
21
bindings/ruby/ext/extconf.rb
Normal file
21
bindings/ruby/ext/extconf.rb
Normal file
@ -0,0 +1,21 @@
|
||||
require 'mkmf'
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
||||
|
||||
|
||||
# need to use c++ compiler flags
|
||||
$CXXFLAGS << ' -std=c++11'
|
||||
# Set to true when building binary gems
|
||||
if enable_config('static-stdlib', false)
|
||||
$LDFLAGS << ' -static-libgcc -static-libstdc++'
|
||||
end
|
||||
|
||||
if enable_config('march-tune-native', false)
|
||||
$CFLAGS << ' -march=native -mtune=native'
|
||||
$CXXFLAGS << ' -march=native -mtune=native'
|
||||
end
|
||||
|
||||
create_makefile('whisper')
|
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
@ -0,0 +1,426 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define BOOL_PARAMS_SETTER(self, prop, value) \
|
||||
ruby_whisper_params *rwp; \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (value == Qfalse || value == Qnil) { \
|
||||
rwp->params.prop = false; \
|
||||
} else { \
|
||||
rwp->params.prop = true; \
|
||||
} \
|
||||
return value; \
|
||||
|
||||
#define BOOL_PARAMS_GETTER(self, prop) \
|
||||
ruby_whisper_params *rwp; \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (rwp->params.prop) { \
|
||||
return Qtrue; \
|
||||
} else { \
|
||||
return Qfalse; \
|
||||
}
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
|
||||
static void ruby_whisper_free(ruby_whisper *rw) {
|
||||
if (rw->context) {
|
||||
whisper_free(rw->context);
|
||||
rw->context = NULL;
|
||||
}
|
||||
}
|
||||
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
|
||||
}
|
||||
|
||||
void rb_whisper_mark(ruby_whisper *rw) {
|
||||
// call rb_gc_mark on any ruby references in rw
|
||||
}
|
||||
|
||||
void rb_whisper_free(ruby_whisper *rw) {
|
||||
ruby_whisper_free(rw);
|
||||
free(rw);
|
||||
}
|
||||
|
||||
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
|
||||
}
|
||||
|
||||
void rb_whisper_params_free(ruby_whisper_params *rwp) {
|
||||
ruby_whisper_params_free(rwp);
|
||||
free(rwp);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_allocate(VALUE klass) {
|
||||
ruby_whisper *rw;
|
||||
rw = ALLOC(ruby_whisper);
|
||||
rw->context = NULL;
|
||||
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
||||
ruby_whisper_params *rwp;
|
||||
rwp = ALLOC(ruby_whisper_params);
|
||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
VALUE whisper_model_file_path;
|
||||
|
||||
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||
}
|
||||
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
|
||||
if (rw->context == nullptr) {
|
||||
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
*
|
||||
**/
|
||||
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
VALUE wave_file_path, blk, params;
|
||||
|
||||
rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
|
||||
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
|
||||
}
|
||||
|
||||
std::string fname_inp = StringValueCStr(wave_file_path);
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input - this is directly from main.cpp example
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true) {
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return self;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
|
||||
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (rwp->diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
VALUE output = rb_str_new2("");
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(rw->context, i);
|
||||
output = rb_str_concat(output, rb_str_new2(text));
|
||||
}
|
||||
VALUE idCall = rb_intern("call");
|
||||
if (blk != Qnil) {
|
||||
rb_funcall(blk, idCall, 1, output);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* params.language = "auto" | "en", etc...
|
||||
*/
|
||||
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->params.language = "auto";
|
||||
} else {
|
||||
rwp->params.language = StringValueCStr(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_language(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->params.language) {
|
||||
return rb_str_new2(rwp->params.language);
|
||||
} else {
|
||||
return rb_str_new2("auto");
|
||||
}
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, translate, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_translate(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, translate)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, no_context, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, no_context)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, single_segment, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, single_segment)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_special, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_special)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_progress, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_progress)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_realtime, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_realtime)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_timestamps, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_timestamps)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_blank, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_blank)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, token_timestamps)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, token_timestamps, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, split_on_word)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, speed_up)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, speed_up, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->diarize) {
|
||||
return Qtrue;
|
||||
} else {
|
||||
return Qfalse;
|
||||
}
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->diarize = false;
|
||||
} else {
|
||||
rwp->diarize = true;
|
||||
} \
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_get_offset(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.offset_ms);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.offset_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_duration(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.duration_ms);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.duration_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.n_max_text_ctx);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.n_max_text_ctx = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
|
||||
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cContext, ruby_whisper_allocate);
|
||||
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
|
||||
|
||||
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
|
||||
|
||||
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
|
||||
|
||||
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
|
||||
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
|
||||
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
|
||||
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
|
||||
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
|
||||
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
|
||||
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
|
||||
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
|
||||
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
|
||||
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
|
||||
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
|
||||
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
|
||||
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
|
||||
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
|
||||
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
|
||||
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
|
||||
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
|
||||
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
|
||||
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
|
||||
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
|
||||
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
|
||||
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
||||
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
||||
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
||||
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
|
||||
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
|
||||
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
||||
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
||||
|
||||
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
|
||||
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
|
||||
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
|
||||
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
|
||||
|
||||
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
|
||||
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
15
bindings/ruby/ext/ruby_whisper.h
Normal file
15
bindings/ruby/ext/ruby_whisper.h
Normal file
@ -0,0 +1,15 @@
|
||||
#ifndef __RUBY_WHISPER_H
|
||||
#define __RUBY_WHISPER_H
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
typedef struct {
|
||||
struct whisper_context *context;
|
||||
} ruby_whisper;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_full_params params;
|
||||
bool diarize;
|
||||
} ruby_whisper_params;
|
||||
|
||||
#endif
|
138
bindings/ruby/tests/test_whisper.rb
Normal file
138
bindings/ruby/tests/test_whisper.rb
Normal file
@ -0,0 +1,138 @@
|
||||
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
||||
EXTDIR = File.join(TOPDIR, 'ext')
|
||||
#$LIBDIR = File.join(TOPDIR, 'lib')
|
||||
#$:.unshift(LIBDIR)
|
||||
$:.unshift(EXTDIR)
|
||||
|
||||
require 'whisper'
|
||||
require 'test/unit'
|
||||
|
||||
class TestWhisper < Test::Unit::TestCase
|
||||
def setup
|
||||
@params = Whisper::Params.new
|
||||
end
|
||||
|
||||
def test_language
|
||||
@params.language = "en"
|
||||
assert_equal @params.language, "en"
|
||||
@params.language = "auto"
|
||||
assert_equal @params.language, "auto"
|
||||
end
|
||||
|
||||
def test_offset
|
||||
@params.offset = 10_000
|
||||
assert_equal @params.offset, 10_000
|
||||
@params.offset = 0
|
||||
assert_equal @params.offset, 0
|
||||
end
|
||||
|
||||
def test_duration
|
||||
@params.duration = 60_000
|
||||
assert_equal @params.duration, 60_000
|
||||
@params.duration = 0
|
||||
assert_equal @params.duration, 0
|
||||
end
|
||||
|
||||
def test_max_text_tokens
|
||||
@params.max_text_tokens = 300
|
||||
assert_equal @params.max_text_tokens, 300
|
||||
@params.max_text_tokens = 0
|
||||
assert_equal @params.max_text_tokens, 0
|
||||
end
|
||||
|
||||
def test_translate
|
||||
@params.translate = true
|
||||
assert @params.translate
|
||||
@params.translate = false
|
||||
assert !@params.translate
|
||||
end
|
||||
|
||||
def test_no_context
|
||||
@params.no_context = true
|
||||
assert @params.no_context
|
||||
@params.no_context = false
|
||||
assert !@params.no_context
|
||||
end
|
||||
|
||||
def test_single_segment
|
||||
@params.single_segment = true
|
||||
assert @params.single_segment
|
||||
@params.single_segment = false
|
||||
assert !@params.single_segment
|
||||
end
|
||||
|
||||
def test_print_special
|
||||
@params.print_special = true
|
||||
assert @params.print_special
|
||||
@params.print_special = false
|
||||
assert !@params.print_special
|
||||
end
|
||||
|
||||
def test_print_progress
|
||||
@params.print_progress = true
|
||||
assert @params.print_progress
|
||||
@params.print_progress = false
|
||||
assert !@params.print_progress
|
||||
end
|
||||
|
||||
def test_print_realtime
|
||||
@params.print_realtime = true
|
||||
assert @params.print_realtime
|
||||
@params.print_realtime = false
|
||||
assert !@params.print_realtime
|
||||
end
|
||||
|
||||
def test_print_timestamps
|
||||
@params.print_timestamps = true
|
||||
assert @params.print_timestamps
|
||||
@params.print_timestamps = false
|
||||
assert !@params.print_timestamps
|
||||
end
|
||||
|
||||
def test_suppress_blank
|
||||
@params.suppress_blank = true
|
||||
assert @params.suppress_blank
|
||||
@params.suppress_blank = false
|
||||
assert !@params.suppress_blank
|
||||
end
|
||||
|
||||
def test_suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = true
|
||||
assert @params.suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = false
|
||||
assert !@params.suppress_non_speech_tokens
|
||||
end
|
||||
|
||||
def test_token_timestamps
|
||||
@params.token_timestamps = true
|
||||
assert @params.token_timestamps
|
||||
@params.token_timestamps = false
|
||||
assert !@params.token_timestamps
|
||||
end
|
||||
|
||||
def test_split_on_word
|
||||
@params.split_on_word = true
|
||||
assert @params.split_on_word
|
||||
@params.split_on_word = false
|
||||
assert !@params.split_on_word
|
||||
end
|
||||
|
||||
def test_speed_up
|
||||
@params.speed_up = true
|
||||
assert @params.speed_up
|
||||
@params.speed_up = false
|
||||
assert !@params.speed_up
|
||||
end
|
||||
|
||||
def test_whisper
|
||||
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
|
||||
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
||||
@whisper.transcribe(jfk, params) {|text|
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
|
||||
}
|
||||
end
|
||||
|
||||
end
|
17
cmake/DefaultTargetOptions.cmake
Normal file
17
cmake/DefaultTargetOptions.cmake
Normal file
@ -0,0 +1,17 @@
|
||||
# Set the default compile features and properties for a target.
|
||||
|
||||
if (NOT TARGET)
|
||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET}
|
||||
PRIVATE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
146
coreml/whisper-decoder-impl.h
Normal file
146
coreml/whisper-decoder-impl.h
Normal file
@ -0,0 +1,146 @@
|
||||
//
|
||||
// whisper-decoder-impl.h
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <CoreML/CoreML.h>
|
||||
#include <stdint.h>
|
||||
#include <os/log.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
||||
/// Model Prediction Input Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// token_data as 1 by 1 matrix of 32-bit integers
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
|
||||
|
||||
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Model Prediction Output Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// var_1346 as multidimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Class for model loading and prediction
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_impl : NSObject
|
||||
@property (readonly, nonatomic, nullable) MLModel * model;
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_decoder_implInput to predict from
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_decoder_implInput to predict from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param token_data as 1 by 1 matrix of 32-bit integers:
|
||||
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Batch prediction
|
||||
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the predictions as NSArray<whisper_decoder_implOutput *>
|
||||
*/
|
||||
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
201
coreml/whisper-decoder-impl.m
Normal file
201
coreml/whisper-decoder-impl.m
Normal file
@ -0,0 +1,201 @@
|
||||
//
|
||||
// whisper-decoder-impl.m
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
|
||||
#endif
|
||||
|
||||
#import "whisper-decoder-impl.h"
|
||||
|
||||
@implementation whisper_decoder_implInput
|
||||
|
||||
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_token_data = token_data;
|
||||
_audio_data = audio_data;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"token_data", @"audio_data"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"token_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.token_data];
|
||||
}
|
||||
if ([featureName isEqualToString:@"audio_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.audio_data];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_decoder_implOutput
|
||||
|
||||
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_var_1346 = var_1346;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"var_1346"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"var_1346"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_decoder_impl
|
||||
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle {
|
||||
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"];
|
||||
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; }
|
||||
return [NSURL fileURLWithPath:assetPath];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model {
|
||||
self = [super init];
|
||||
if (!self) { return nil; }
|
||||
_model = model;
|
||||
if (_model == nil) { return nil; }
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
|
||||
configuration:configuration
|
||||
completionHandler:handler];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[MLModel loadContentsOfURL:modelURL
|
||||
configuration:configuration
|
||||
completionHandler:^(MLModel *model, NSError *error) {
|
||||
if (model != nil) {
|
||||
whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model];
|
||||
handler(typedModel, nil);
|
||||
} else {
|
||||
handler(nil, error);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
|
||||
if (!outFeatures) { return nil; }
|
||||
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data];
|
||||
return [self predictionFromFeatures:input_ error:error];
|
||||
}
|
||||
|
||||
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
|
||||
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
|
||||
if (!outBatch) { return nil; }
|
||||
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
|
||||
for (NSInteger i = 0; i < outBatch.count; i++) {
|
||||
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
|
||||
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
|
||||
[results addObject:result];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@end
|
142
coreml/whisper-encoder-impl.h
Normal file
142
coreml/whisper-encoder-impl.h
Normal file
@ -0,0 +1,142 @@
|
||||
//
|
||||
// whisper-encoder-impl.h
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <CoreML/CoreML.h>
|
||||
#include <stdint.h>
|
||||
#include <os/log.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
||||
/// Model Prediction Input Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Model Prediction Output Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// output as multidimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * output;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Class for model loading and prediction
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_impl : NSObject
|
||||
@property (readonly, nonatomic, nullable) MLModel * model;
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_encoder_implInput to predict from
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_encoder_implInput to predict from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Batch prediction
|
||||
@param inputArray array of whisper_encoder_implInput instances to obtain predictions from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the predictions as NSArray<whisper_encoder_implOutput *>
|
||||
*/
|
||||
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
197
coreml/whisper-encoder-impl.m
Normal file
197
coreml/whisper-encoder-impl.m
Normal file
@ -0,0 +1,197 @@
|
||||
//
|
||||
// whisper-encoder-impl.m
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
|
||||
#endif
|
||||
|
||||
#import "whisper-encoder-impl.h"
|
||||
|
||||
@implementation whisper_encoder_implInput
|
||||
|
||||
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_logmel_data = logmel_data;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"logmel_data"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"logmel_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.logmel_data];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_encoder_implOutput
|
||||
|
||||
- (instancetype)initWithOutput:(MLMultiArray *)output {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_output = output;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"output"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"output"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.output];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_encoder_impl
|
||||
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle {
|
||||
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"];
|
||||
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; }
|
||||
return [NSURL fileURLWithPath:assetPath];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model {
|
||||
self = [super init];
|
||||
if (!self) { return nil; }
|
||||
_model = model;
|
||||
if (_model == nil) { return nil; }
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
|
||||
configuration:configuration
|
||||
completionHandler:handler];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[MLModel loadContentsOfURL:modelURL
|
||||
configuration:configuration
|
||||
completionHandler:^(MLModel *model, NSError *error) {
|
||||
if (model != nil) {
|
||||
whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model];
|
||||
handler(typedModel, nil);
|
||||
} else {
|
||||
handler(nil, error);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
|
||||
if (!outFeatures) { return nil; }
|
||||
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
|
||||
return [self predictionFromFeatures:input_ error:error];
|
||||
}
|
||||
|
||||
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
|
||||
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
|
||||
if (!outBatch) { return nil; }
|
||||
NSMutableArray<whisper_encoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
|
||||
for (NSInteger i = 0; i < outBatch.count; i++) {
|
||||
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
|
||||
whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
|
||||
[results addObject:result];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@end
|
22
coreml/whisper-encoder.h
Normal file
22
coreml/whisper-encoder.h
Normal file
@ -0,0 +1,22 @@
|
||||
// Wrapper of the Core ML Whisper Encoder model
|
||||
//
|
||||
// Code is derived from the work of Github user @wangchou
|
||||
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_coreml_context;
|
||||
|
||||
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
|
||||
void whisper_coreml_free(struct whisper_coreml_context * ctx);
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
float * mel,
|
||||
float * out);
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif
|
67
coreml/whisper-encoder.mm
Normal file
67
coreml/whisper-encoder.mm
Normal file
@ -0,0 +1,67 @@
|
||||
#import "coreml/whisper-encoder.h"
|
||||
#import "coreml/whisper-encoder-impl.h"
|
||||
|
||||
#import <CoreML/CoreML.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_coreml_context {
|
||||
const void * data;
|
||||
};
|
||||
|
||||
struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
|
||||
NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
|
||||
|
||||
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
|
||||
|
||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
|
||||
|
||||
if (data == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_coreml_context * ctx = new whisper_coreml_context;
|
||||
|
||||
ctx->data = data;
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void whisper_coreml_free(struct whisper_coreml_context * ctx) {
|
||||
CFRelease(ctx->data);
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
float * mel,
|
||||
float * out) {
|
||||
MLMultiArray * inMultiArray = [
|
||||
[MLMultiArray alloc] initWithDataPointer: mel
|
||||
shape: @[@1, @80, @3000]
|
||||
dataType: MLMultiArrayDataTypeFloat32
|
||||
strides: @[@(240000), @(3000), @1]
|
||||
deallocator: nil
|
||||
error: nil
|
||||
];
|
||||
|
||||
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
|
||||
|
||||
MLMultiArray * outMA = outCoreML.output;
|
||||
|
||||
//NSArray<NSNumber *> * shape = outMA.shape;
|
||||
//NSArray<NSNumber *> * strides = outMA.strides;
|
||||
|
||||
//printf("shape: %ld %ld %ld %ld\n", [shape[0] longValue], [shape[1] longValue], [shape[2] longValue], [shape[3] longValue]);
|
||||
//printf("strides: %ld %ld %ld %ld\n", [strides[0] longValue], [strides[1] longValue], [strides[2] longValue], [strides[3] longValue]);
|
||||
|
||||
memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
|
||||
}
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif
|
@ -4,7 +4,7 @@ find_package(Threads REQUIRED)
|
||||
|
||||
# third-party
|
||||
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# SDL2
|
||||
find_package(SDL2 REQUIRED)
|
||||
|
||||
@ -14,6 +14,41 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||
endif()
|
||||
|
||||
# common
|
||||
|
||||
set(TARGET common)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
common.h
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
if (WHISPER_SDL2)
|
||||
# common-sdl
|
||||
|
||||
set(TARGET common-sdl)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
common-sdl.h
|
||||
common-sdl.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# examples
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
@ -24,10 +59,14 @@ if (EMSCRIPTEN)
|
||||
add_subdirectory(command.wasm)
|
||||
add_subdirectory(talk.wasm)
|
||||
add_subdirectory(bench.wasm)
|
||||
elseif(CMAKE_JS_VERSION)
|
||||
add_subdirectory(addon.node)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(quantize)
|
||||
add_subdirectory(talk)
|
||||
add_subdirectory(talk-llama)
|
||||
endif()
|
||||
|
3
examples/addon.node/.gitignore
vendored
Normal file
3
examples/addon.node/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
.idea
|
||||
node_modules
|
||||
build
|
31
examples/addon.node/CMakeLists.txt
Normal file
31
examples/addon.node/CMakeLists.txt
Normal file
@ -0,0 +1,31 @@
|
||||
set(TARGET whisper-addon)
|
||||
|
||||
# Base settings
|
||||
#==================================================================
|
||||
# env var supported by cmake-js
|
||||
add_definitions(-DNAPI_VERSION=4)
|
||||
include_directories(${CMAKE_JS_INC})
|
||||
#==================================================================
|
||||
|
||||
add_library(${TARGET} SHARED ${CMAKE_JS_SRC} addon.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES PREFIX "" SUFFIX ".node")
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
# Include N-API wrappers
|
||||
#==================================================================
|
||||
execute_process(COMMAND node -p "require('node-addon-api').include"
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE NODE_ADDON_API_DIR
|
||||
)
|
||||
string(REPLACE "\n" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
|
||||
string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
|
||||
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
|
||||
#==================================================================
|
||||
|
||||
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
|
||||
# Generate node.lib
|
||||
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
||||
endif()
|
37
examples/addon.node/README.md
Normal file
37
examples/addon.node/README.md
Normal file
@ -0,0 +1,37 @@
|
||||
# addon
|
||||
|
||||
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
|
||||
It can be used as a reference for using the whisper.cpp project in other node projects.
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
npm install
|
||||
```
|
||||
|
||||
## Compile
|
||||
|
||||
Make sure it is in the project root directory and compiled with make-js.
|
||||
|
||||
```shell
|
||||
npx cmake-js compile -T whisper-addon -B Release
|
||||
```
|
||||
|
||||
For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes.
|
||||
|
||||
> Such as appointing special cmake path:
|
||||
> ```shell
|
||||
> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release
|
||||
> ```
|
||||
|
||||
## Run
|
||||
|
||||
```shell
|
||||
cd examples/addon.node
|
||||
|
||||
node index.js --language='language' --model='model-path' --fname_inp='file-path'
|
||||
```
|
||||
|
||||
Because this is a simple Demo, only the above parameters are set in the node environment.
|
||||
|
||||
Other parameters can also be specified in the node environment.
|
23
examples/addon.node/__test__/whisper.spec.js
Normal file
23
examples/addon.node/__test__/whisper.spec.js
Normal file
@ -0,0 +1,23 @@
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../../build/Release/whisper-addon"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
const whisperParamsMock = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
|
||||
};
|
||||
|
||||
describe("Run whisper.node", () => {
|
||||
test("it should receive a non-empty value", async () => {
|
||||
let result = await whisperAsync(whisperParamsMock);
|
||||
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
}, 10000);
|
||||
});
|
||||
|
338
examples/addon.node/addon.cpp
Normal file
338
examples/addon.node/addon.cpp
Normal file
@ -0,0 +1,338 @@
|
||||
#include "napi.h"
|
||||
#include "common.h"
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 5;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string model = "../../ggml-large.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
std::vector<std::string> fname_out = {};
|
||||
};
|
||||
|
||||
struct whisper_print_user_data {
|
||||
const whisper_params * params;
|
||||
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
int64_t msec = t * 10;
|
||||
int64_t hr = msec / (1000 * 60 * 60);
|
||||
msec = msec - hr * (1000 * 60 * 60);
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
int timestamp_to_sample(int64_t t, int n_samples) {
|
||||
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
||||
}
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
// colorful print bug
|
||||
//
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
if (params.fname_inp.empty()) {
|
||||
fprintf(stderr, "error: no input files specified\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (!wparams.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
// example for abort mechanism
|
||||
// in this example, we do not abort the processing, but we could if the flag is set to true
|
||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return 10;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
result.resize(n_segments);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
result[i].emplace_back(to_timestamp(t0, true));
|
||||
result[i].emplace_back(to_timestamp(t1, true));
|
||||
result[i].emplace_back(text);
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
class Worker : public Napi::AsyncWorker {
|
||||
public:
|
||||
Worker(Napi::Function& callback, whisper_params params)
|
||||
: Napi::AsyncWorker(callback), params(params) {}
|
||||
|
||||
void Execute() override {
|
||||
run(params, result);
|
||||
}
|
||||
|
||||
void OnOK() override {
|
||||
Napi::HandleScope scope(Env());
|
||||
Napi::Object res = Napi::Array::New(Env(), result.size());
|
||||
for (uint64_t i = 0; i < result.size(); ++i) {
|
||||
Napi::Object tmp = Napi::Array::New(Env(), 3);
|
||||
for (uint64_t j = 0; j < 3; ++j) {
|
||||
tmp[j] = Napi::String::New(Env(), result[i][j]);
|
||||
}
|
||||
res[i] = tmp;
|
||||
}
|
||||
Callback().Call({Env().Null(), res});
|
||||
}
|
||||
|
||||
private:
|
||||
whisper_params params;
|
||||
std::vector<std::vector<std::string>> result;
|
||||
};
|
||||
|
||||
|
||||
|
||||
Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
Napi::Env env = info.Env();
|
||||
if (info.Length() <= 0 || !info[0].IsObject()) {
|
||||
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException();
|
||||
}
|
||||
whisper_params params;
|
||||
|
||||
Napi::Object whisper_params = info[0].As<Napi::Object>();
|
||||
std::string language = whisper_params.Get("language").As<Napi::String>();
|
||||
std::string model = whisper_params.Get("model").As<Napi::String>();
|
||||
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
|
||||
|
||||
params.language = language;
|
||||
params.model = model;
|
||||
params.fname_inp.emplace_back(input);
|
||||
|
||||
Napi::Function callback = info[1].As<Napi::Function>();
|
||||
Worker* worker = new Worker(callback, params);
|
||||
worker->Queue();
|
||||
return env.Undefined();
|
||||
}
|
||||
|
||||
|
||||
Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
||||
exports.Set(
|
||||
Napi::String::New(env, "whisper"),
|
||||
Napi::Function::New(env, whisper)
|
||||
);
|
||||
return exports;
|
||||
}
|
||||
|
||||
NODE_API_MODULE(whisper, Init);
|
36
examples/addon.node/index.js
Normal file
36
examples/addon.node/index.js
Normal file
@ -0,0 +1,36 @@
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../build/Release/whisper-addon"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
const whisperParams = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
||||
fname_inp: "../../samples/jfk.wav",
|
||||
};
|
||||
|
||||
const arguments = process.argv.slice(2);
|
||||
const params = Object.fromEntries(
|
||||
arguments.reduce((pre, item) => {
|
||||
if (item.startsWith("--")) {
|
||||
return [...pre, item.slice(2).split("=")];
|
||||
}
|
||||
return pre;
|
||||
}, [])
|
||||
);
|
||||
|
||||
for (const key in params) {
|
||||
if (whisperParams.hasOwnProperty(key)) {
|
||||
whisperParams[key] = params[key];
|
||||
}
|
||||
}
|
||||
|
||||
console.log("whisperParams =", whisperParams);
|
||||
|
||||
whisperAsync(whisperParams).then((result) => {
|
||||
console.log(`Result from whisper: ${result}`);
|
||||
});
|
16
examples/addon.node/package.json
Normal file
16
examples/addon.node/package.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"name": "whisper-addon",
|
||||
"version": "0.0.0",
|
||||
"description": "",
|
||||
"main": "index.js",
|
||||
"author": "Qanhe Chen",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
"test": "jest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cmake-js": "^7.1.1",
|
||||
"jest": "^29.4.0",
|
||||
"node-addon-api": "^5.0.0"
|
||||
}
|
||||
}
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
@ -29,9 +31,9 @@ endif()
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s PTHREAD_POOL_SIZE_STRICT=0 \
|
||||
-s INITIAL_MEMORY=2000MB \
|
||||
-s TOTAL_MEMORY=2000MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -28,6 +28,11 @@ void bench_main(size_t index) {
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return;
|
||||
@ -52,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use and click the "Bench" button.<br>
|
||||
@ -44,11 +53,18 @@
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
|
||||
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<button id="fetch-whisper-small-en-q5_1" onclick="loadWhisper('small-en-q5_1')">small.en (Q5_1, 182 MB)</button>
|
||||
<button id="fetch-whisper-medium-en-q5_0" onclick="loadWhisper('medium-en-q5_0')">medium.en (Q5_0, 515 MB)</button>
|
||||
<button id="fetch-whisper-large-q5_0" onclick="loadWhisper('large-q5_0')">large (Q5_0, 1030 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@ -160,6 +176,14 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
}
|
||||
@ -168,19 +192,42 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
|
||||
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
|
||||
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
'small.en': 466,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
'small-en-q5_1': 182,
|
||||
'medium-en-q5_0': 515,
|
||||
'large-q5_0': 1030,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -190,9 +237,18 @@
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-medium-en-q5_0'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-large-q5_0' ); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
|
@ -1,3 +1,6 @@
|
||||
set(TARGET bench)
|
||||
add_executable(${TARGET} bench.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -7,6 +7,7 @@
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
|
||||
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
};
|
||||
@ -23,6 +24,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -33,7 +35,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
@ -41,19 +43,17 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
||||
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
|
||||
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
||||
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int whisper_bench_encoder(const whisper_params & params) {
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
@ -92,3 +92,22 @@ int main(int argc, char ** argv) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ret = -1;
|
||||
|
||||
switch (params.what) {
|
||||
case 0: ret = whisper_bench_encoder(params); break;
|
||||
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
|
||||
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
|
||||
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
@ -8,7 +8,10 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
common
|
||||
whisper
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
@ -27,92 +28,11 @@ std::string g_transcribed = "";
|
||||
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
static std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
static float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
void command_set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
@ -155,7 +75,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
|
||||
const int64_t n_samples = (ms * sample_rate) / 1000;
|
||||
|
||||
int64_t n_take = 0;
|
||||
if (g_pcmf32.size() < n_samples) {
|
||||
if (n_samples > (int) g_pcmf32.size()) {
|
||||
n_take = g_pcmf32.size();
|
||||
} else {
|
||||
n_take = n_samples;
|
||||
@ -187,7 +107,6 @@ void command_main(size_t index) {
|
||||
|
||||
printf("command: using %d threads\n", wparams.n_threads);
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
bool print_energy = false;
|
||||
@ -233,7 +152,7 @@ void command_main(size_t index) {
|
||||
{
|
||||
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||
|
||||
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
command_set_status("Speech detected! Processing ...");
|
||||
|
||||
@ -324,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and follow the instructions.
|
||||
@ -45,6 +54,10 @@
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
@ -162,11 +175,17 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -177,6 +196,10 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -188,6 +211,10 @@
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -9,7 +9,19 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Guided mode
|
||||
|
||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
||||
|
||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
||||
|
||||
```bash
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
@ -17,9 +29,8 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Building
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
241
examples/common-ggml.cpp
Normal file
241
examples/common-ggml.cpp
Normal file
@ -0,0 +1,241 @@
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <regex>
|
||||
#include <map>
|
||||
|
||||
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
|
||||
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
|
||||
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
|
||||
{"q4_2", GGML_FTYPE_MOSTLY_Q4_2},
|
||||
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
|
||||
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
|
||||
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
|
||||
};
|
||||
|
||||
void ggml_print_ftypes(FILE * fp) {
|
||||
for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {
|
||||
fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
|
||||
}
|
||||
}
|
||||
|
||||
enum ggml_ftype ggml_parse_ftype(const char * str) {
|
||||
enum ggml_ftype ftype;
|
||||
if (str[0] == 'q') {
|
||||
const auto it = GGML_FTYPE_MAP.find(str);
|
||||
if (it == GGML_FTYPE_MAP.end()) {
|
||||
fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str);
|
||||
return GGML_FTYPE_UNKNOWN;
|
||||
}
|
||||
ftype = it->second;
|
||||
} else {
|
||||
ftype = (enum ggml_ftype) atoi(str);
|
||||
}
|
||||
|
||||
return ftype;
|
||||
}
|
||||
|
||||
bool ggml_common_quantize_0(
|
||||
std::ifstream & finp,
|
||||
std::ofstream & fout,
|
||||
const ggml_ftype ftype,
|
||||
const std::vector<std::string> & to_quant,
|
||||
const std::vector<std::string> & to_skip) {
|
||||
|
||||
ggml_type qtype = GGML_TYPE_F32;
|
||||
|
||||
switch (ftype) {
|
||||
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_2: qtype = GGML_TYPE_Q4_2; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
|
||||
case GGML_FTYPE_UNKNOWN:
|
||||
case GGML_FTYPE_ALL_F32:
|
||||
case GGML_FTYPE_MOSTLY_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
if (!ggml_is_quantized(qtype)) {
|
||||
fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, qtype, ggml_type_name(qtype));
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
size_t total_size_new = 0;
|
||||
|
||||
std::vector<float> work;
|
||||
|
||||
std::vector<uint8_t> data_u8;
|
||||
std::vector<ggml_fp16_t> data_f16;
|
||||
std::vector<float> data_f32;
|
||||
|
||||
std::vector<int64_t> hist_all(1 << 4, 0);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ttype;
|
||||
|
||||
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
finp.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
finp.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
|
||||
if (finp.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
finp.read (&name[0], length);
|
||||
|
||||
printf("%64s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ggml_type_name((ggml_type) ttype));
|
||||
|
||||
bool quantize = false;
|
||||
|
||||
// check if we should quantize this tensor
|
||||
for (const auto & s : to_quant) {
|
||||
if (std::regex_match(name, std::regex(s))) {
|
||||
quantize = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// check if we should skip this tensor
|
||||
for (const auto & s : to_skip) {
|
||||
if (std::regex_match(name, std::regex(s))) {
|
||||
quantize = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// quantize only 2D tensors
|
||||
quantize &= (n_dims == 2);
|
||||
|
||||
if (quantize) {
|
||||
if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {
|
||||
fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ttype == GGML_TYPE_F16) {
|
||||
data_f16.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
||||
data_f32.resize(nelements);
|
||||
for (int i = 0; i < nelements; ++i) {
|
||||
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
|
||||
}
|
||||
} else {
|
||||
data_f32.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
|
||||
}
|
||||
|
||||
ttype = qtype;
|
||||
} else {
|
||||
const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t);
|
||||
|
||||
data_u8.resize(nelements*bpe);
|
||||
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fout.write(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
}
|
||||
fout.write(&name[0], length);
|
||||
|
||||
if (quantize) {
|
||||
work.resize(nelements); // for quantization
|
||||
|
||||
size_t cur_size = 0;
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
switch ((ggml_type) ttype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_2:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
|
||||
total_size_new += cur_size;
|
||||
|
||||
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
|
||||
for (int i = 0; i < hist_cur.size(); ++i) {
|
||||
hist_all[i] += hist_cur[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < hist_cur.size(); ++i) {
|
||||
printf("%5.3f ", hist_cur[i] / (float)nelements);
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
|
||||
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
|
||||
total_size_new += data_u8.size();
|
||||
}
|
||||
|
||||
total_size_org += nelements * sizeof(float);
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
|
||||
printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype));
|
||||
|
||||
{
|
||||
int64_t sum_all = 0;
|
||||
for (int i = 0; i < hist_all.size(); ++i) {
|
||||
sum_all += hist_all[i];
|
||||
}
|
||||
|
||||
printf("%s: hist: ", __func__);
|
||||
for (int i = 0; i < hist_all.size(); ++i) {
|
||||
printf("%5.3f ", hist_all[i] / (float)sum_all);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
18
examples/common-ggml.h
Normal file
18
examples/common-ggml.h
Normal file
@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
enum ggml_ftype ggml_parse_ftype(const char * str);
|
||||
|
||||
void ggml_print_ftypes(FILE * fp = stderr);
|
||||
|
||||
bool ggml_common_quantize_0(
|
||||
std::ifstream & finp,
|
||||
std::ofstream & fout,
|
||||
const ggml_ftype ftype,
|
||||
const std::vector<std::string> & to_quant,
|
||||
const std::vector<std::string> & to_skip);
|
226
examples/common-sdl.cpp
Normal file
226
examples/common-sdl.cpp
Normal file
@ -0,0 +1,226 @@
|
||||
#include "common-sdl.h"
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool sdl_poll_events() {
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
return false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
50
examples/common-sdl.h
Normal file
50
examples/common-sdl.h
Normal file
@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
// Return false if need to quit
|
||||
bool sdl_poll_events();
|
505
examples/common.cpp
Normal file
505
examples/common.cpp
Normal file
@ -0,0 +1,505 @@
|
||||
#include "common.h"
|
||||
|
||||
// third-party utilities
|
||||
// use your favorite implementations
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.14159265358979323846
|
||||
#endif
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-s" || arg == "--seed") {
|
||||
params.seed = std::stoi(argv[++i]);
|
||||
} else if (arg == "-t" || arg == "--threads") {
|
||||
params.n_threads = std::stoi(argv[++i]);
|
||||
} else if (arg == "-p" || arg == "--prompt") {
|
||||
params.prompt = argv[++i];
|
||||
} else if (arg == "-n" || arg == "--n_predict") {
|
||||
params.n_predict = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_k") {
|
||||
params.top_k = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_p") {
|
||||
params.top_p = std::stof(argv[++i]);
|
||||
} else if (arg == "--temp") {
|
||||
params.temp = std::stof(argv[++i]);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
params.n_batch = std::stoi(argv[++i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
params.model = argv[++i];
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
gpt_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
gpt_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
|
||||
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
|
||||
fprintf(stderr, " prompt to start generation with (default: random)\n");
|
||||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||
const int r = rng() % 10;
|
||||
switch (r) {
|
||||
case 0: return "So";
|
||||
case 1: return "Once upon a time";
|
||||
case 2: return "When";
|
||||
case 3: return "The";
|
||||
case 4: return "After";
|
||||
case 5: return "If";
|
||||
case 6: return "import";
|
||||
case 7: return "He";
|
||||
case 8: return "She";
|
||||
case 9: return "They";
|
||||
default: return "To";
|
||||
}
|
||||
|
||||
return "The";
|
||||
}
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||
std::string result = s;
|
||||
size_t pos = 0;
|
||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||
result.replace(pos, from.length(), to);
|
||||
pos += to.length();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||
std::map<std::string, int32_t> result;
|
||||
|
||||
// read file into string
|
||||
std::string json;
|
||||
{
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
}
|
||||
|
||||
if (json[0] != '{') {
|
||||
return result;
|
||||
}
|
||||
|
||||
// parse json
|
||||
{
|
||||
bool has_key = false;
|
||||
bool in_token = false;
|
||||
|
||||
std::string str_key = "";
|
||||
std::string str_val = "";
|
||||
|
||||
int n = json.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
if (!in_token) {
|
||||
if (json[i] == ' ') continue;
|
||||
if (json[i] == '"') {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (json[i] == '\\' && i+1 < n) {
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
++i;
|
||||
} else if (json[i] == '"') {
|
||||
if (has_key == false) {
|
||||
has_key = true;
|
||||
++i;
|
||||
while (json[i] == ' ') ++i;
|
||||
++i; // :
|
||||
while (json[i] == ' ') ++i;
|
||||
if (json[i] != '\"') {
|
||||
while (json[i] != ',' && json[i] != '}') {
|
||||
str_val += json[i++];
|
||||
}
|
||||
has_key = false;
|
||||
} else {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
has_key = false;
|
||||
}
|
||||
|
||||
str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||
str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||
str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||
|
||||
try {
|
||||
result[str_key] = std::stoi(str_val);
|
||||
} catch (...) {
|
||||
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||
|
||||
}
|
||||
str_key = "";
|
||||
str_val = "";
|
||||
in_token = false;
|
||||
continue;
|
||||
}
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
vocab.token_to_id = ::json_parse(fname);
|
||||
|
||||
for (const auto & kv : vocab.token_to_id) {
|
||||
vocab.id_to_token[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||
|
||||
// print the vocabulary
|
||||
//for (auto kv : vocab.token_to_id) {
|
||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||
//}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
{
|
||||
const double scale = 1.0/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
double maxl = -INFINITY;
|
||||
for (const auto & kv : logits_id) {
|
||||
maxl = std::max(maxl, kv.first);
|
||||
}
|
||||
|
||||
// compute probs for the top K tokens
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
double sum = 0.0;
|
||||
for (const auto & kv : logits_id) {
|
||||
double p = exp(kv.first - maxl);
|
||||
probs.push_back(p);
|
||||
sum += p;
|
||||
}
|
||||
|
||||
// normalize the probs
|
||||
for (auto & p : probs) {
|
||||
p /= sum;
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += probs[i];
|
||||
if (cumsum >= top_p) {
|
||||
top_k = i + 1;
|
||||
probs.resize(top_k);
|
||||
logits_id.resize(top_k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cumsum = 1.0/cumsum;
|
||||
for (int i = 0; i < (int) probs.size(); i++) {
|
||||
probs[i] *= cumsum;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) probs.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true)
|
||||
{
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stereo && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (stereo) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
122
examples/common.h
Normal file
122
examples/common.h
Normal file
@ -0,0 +1,122 @@
|
||||
// Various helper functions and utilities
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
#define COMMON_SAMPLE_RATE 16000
|
||||
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
|
||||
struct gpt_params {
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 200; // new tokens to predict
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.9f;
|
||||
|
||||
int32_t n_batch = 8; // batch size for prompt processing
|
||||
|
||||
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
|
||||
std::string prompt;
|
||||
};
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||
|
||||
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::string trim(const std::string & s);
|
||||
|
||||
std::string replace(
|
||||
const std::string & s,
|
||||
const std::string & from,
|
||||
const std::string & to);
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
// poor-man's JSON parsing
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||
|
||||
// load the tokens from encoder.json
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||
|
||||
// sample next token given probabilities for each embedding
|
||||
//
|
||||
// - consider only the top K tokens
|
||||
// - from them, consider only the top tokens with cumulative probability > P
|
||||
//
|
||||
// TODO: not sure if this implementation is correct
|
||||
// TODO: temperature is not implemented
|
||||
//
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Audio utils
|
||||
//
|
||||
|
||||
// Read WAV audio file and store the PCM data into pcmf32
|
||||
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
|
||||
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
|
||||
bool read_wav(
|
||||
const std::string & fname,
|
||||
std::vector<float> & pcmf32,
|
||||
std::vector<std::vector<float>> & pcmf32s,
|
||||
bool stereo);
|
||||
|
||||
// Apply a high-pass frequency filter to PCM audio
|
||||
// Suppresses frequencies below cutoff Hz
|
||||
void high_pass_filter(
|
||||
std::vector<float> & data,
|
||||
float cutoff,
|
||||
float sample_rate);
|
||||
|
||||
// Basic voice activity detection (VAD) using audio energy adaptive threshold
|
||||
bool vad_simple(
|
||||
std::vector<float> & pcmf32,
|
||||
int sample_rate,
|
||||
int last_ms,
|
||||
float vad_thold,
|
||||
float freq_thold,
|
||||
bool verbose);
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
float similarity(const std::string & s0, const std::string & s1);
|
@ -8,7 +8,7 @@ function convertTypedArray(src, type) {
|
||||
|
||||
var printTextarea = (function() {
|
||||
var element = document.getElementById('output');
|
||||
if (element) element.alue = ''; // clear browser cache
|
||||
if (element) element.value = ''; // clear browser cache
|
||||
return function(text) {
|
||||
if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
|
||||
console.log(text);
|
||||
@ -88,11 +88,15 @@ async function fetchRemote(url, cbProgress, cbPrint) {
|
||||
// - check if the data is already in the IndexedDB
|
||||
// - if not, fetch it from the remote URL and store it in the IndexedDB
|
||||
function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
// query the storage quota and print it
|
||||
navigator.storage.estimate().then(function (estimate) {
|
||||
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
|
||||
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
|
||||
});
|
||||
if (!navigator.storage || !navigator.storage.estimate) {
|
||||
cbPrint('loadRemote: navigator.storage.estimate() is not supported');
|
||||
} else {
|
||||
// query the storage quota and print it
|
||||
navigator.storage.estimate().then(function (estimate) {
|
||||
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
|
||||
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
|
||||
});
|
||||
}
|
||||
|
||||
// check if the data is already in the IndexedDB
|
||||
var rq = indexedDB.open(dbName, dbVersion);
|
||||
@ -141,7 +145,15 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
var db = event.target.result;
|
||||
var tx = db.transaction(['models'], 'readwrite');
|
||||
var os = tx.objectStore('models');
|
||||
var rq = os.put(data, url);
|
||||
|
||||
var rq = null;
|
||||
try {
|
||||
var rq = os.put(data, url);
|
||||
} catch (e) {
|
||||
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
|
||||
cbCancel();
|
||||
return;
|
||||
}
|
||||
|
||||
rq.onsuccess = function (event) {
|
||||
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
|
||||
@ -176,7 +188,6 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
|
||||
rq.onabort = function (event) {
|
||||
cbPrint('loadRemote: failed to open IndexedDB: abort');
|
||||
|
||||
cbCancel();
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ while [ $running -eq 1 ]; do
|
||||
err=$(cat /tmp/whisper-live.err | wc -l)
|
||||
done
|
||||
|
||||
./main -t 8 -m ./models/ggml-base.en.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
|
||||
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
|
||||
sleep 1
|
||||
|
@ -1,3 +1,6 @@
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -9,25 +9,36 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
```
|
||||
|
@ -1,9 +1,6 @@
|
||||
#include "whisper.h"
|
||||
#include "common.h"
|
||||
|
||||
// third-party utilities
|
||||
// use your favorite implementations
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
@ -11,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
||||
// Lowest is red, middle is yellow, highest is green.
|
||||
@ -53,32 +51,43 @@ void replace_all(std::string & s, const std::string & search, const std::string
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
float logprob_thold = -1.00f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool split_on_word = false;
|
||||
bool no_fallback = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool output_jsn = false;
|
||||
bool output_lrc = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
std::string prompt;
|
||||
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
std::vector<std::string> fname_out = {};
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -87,6 +96,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-"){
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg[0] != '-') {
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
@ -103,14 +117,25 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
@ -118,7 +143,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -129,35 +154,46 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
||||
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -167,96 +203,87 @@ struct whisper_print_user_data {
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0 = 0;
|
||||
int64_t t1 = 0;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
}
|
||||
fflush(stdout);
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
|
||||
}
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
@ -325,6 +352,186 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
char *escape_double_quotes_and_backslashes(const char *str) {
|
||||
if (str == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t escaped_length = strlen(str) + 1;
|
||||
|
||||
for (size_t i = 0; str[i] != '\0'; i++) {
|
||||
if (str[i] == '"' || str[i] == '\\') {
|
||||
escaped_length++;
|
||||
}
|
||||
}
|
||||
|
||||
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
|
||||
if (escaped == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t pos = 0;
|
||||
for (size_t i = 0; str[i] != '\0'; i++) {
|
||||
if (str[i] == '"' || str[i] == '\\') {
|
||||
escaped[pos++] = '\\';
|
||||
}
|
||||
escaped[pos++] = str[i];
|
||||
}
|
||||
|
||||
// no need to set zero due to calloc() being used prior
|
||||
|
||||
return escaped;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
fout << "start,end,text\n";
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
char * text_escaped = escape_double_quotes_and_backslashes(text);
|
||||
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
|
||||
std::ofstream fout(fname);
|
||||
int indent = 0;
|
||||
|
||||
auto doindent = [&]() {
|
||||
for (int i = 0; i < indent; i++) fout << "\t";
|
||||
};
|
||||
|
||||
auto start_arr = [&](const char *name) {
|
||||
doindent();
|
||||
fout << "\"" << name << "\": [\n";
|
||||
indent++;
|
||||
};
|
||||
|
||||
auto end_arr = [&](bool end = false) {
|
||||
indent--;
|
||||
doindent();
|
||||
fout << (end ? "]\n" : "},\n");
|
||||
};
|
||||
|
||||
auto start_obj = [&](const char *name = nullptr) {
|
||||
doindent();
|
||||
if (name) {
|
||||
fout << "\"" << name << "\": {\n";
|
||||
} else {
|
||||
fout << "{\n";
|
||||
}
|
||||
indent++;
|
||||
};
|
||||
|
||||
auto end_obj = [&](bool end = false) {
|
||||
indent--;
|
||||
doindent();
|
||||
fout << (end ? "}\n" : "},\n");
|
||||
};
|
||||
|
||||
auto start_value = [&](const char *name) {
|
||||
doindent();
|
||||
fout << "\"" << name << "\": ";
|
||||
};
|
||||
|
||||
auto value_s = [&](const char *name, const char *val, bool end = false) {
|
||||
start_value(name);
|
||||
char * val_escaped = escape_double_quotes_and_backslashes(val);
|
||||
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
|
||||
free(val_escaped);
|
||||
};
|
||||
|
||||
auto end_value = [&](bool end = false) {
|
||||
fout << (end ? "\n" : ",\n");
|
||||
};
|
||||
|
||||
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
|
||||
start_value(name);
|
||||
fout << val;
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
auto value_b = [&](const char *name, const bool val, bool end = false) {
|
||||
start_value(name);
|
||||
fout << (val ? "true" : "false");
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
start_obj();
|
||||
value_s("systeminfo", whisper_print_system_info());
|
||||
start_obj("model");
|
||||
value_s("type", whisper_model_type_readable(ctx));
|
||||
value_b("multilingual", whisper_is_multilingual(ctx));
|
||||
value_i("vocab", whisper_model_n_vocab(ctx));
|
||||
start_obj("audio");
|
||||
value_i("ctx", whisper_model_n_audio_ctx(ctx));
|
||||
value_i("state", whisper_model_n_audio_state(ctx));
|
||||
value_i("head", whisper_model_n_audio_head(ctx));
|
||||
value_i("layer", whisper_model_n_audio_layer(ctx), true);
|
||||
end_obj();
|
||||
start_obj("text");
|
||||
value_i("ctx", whisper_model_n_text_ctx(ctx));
|
||||
value_i("state", whisper_model_n_text_state(ctx));
|
||||
value_i("head", whisper_model_n_text_head(ctx));
|
||||
value_i("layer", whisper_model_n_text_layer(ctx), true);
|
||||
end_obj();
|
||||
value_i("mels", whisper_model_n_mels(ctx));
|
||||
value_i("ftype", whisper_model_ftype(ctx), true);
|
||||
end_obj();
|
||||
start_obj("params");
|
||||
value_s("model", params.model.c_str());
|
||||
value_s("language", params.language.c_str());
|
||||
value_b("translate", params.translate, true);
|
||||
end_obj();
|
||||
start_obj("result");
|
||||
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
|
||||
end_obj();
|
||||
start_arr("transcription");
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
start_obj();
|
||||
start_obj("timestamps");
|
||||
value_s("from", to_timestamp(t0, true).c_str());
|
||||
value_s("to", to_timestamp(t1, true).c_str(), true);
|
||||
end_obj();
|
||||
start_obj("offsets");
|
||||
value_i("from", t0 * 10);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj();
|
||||
value_s("text", text, true);
|
||||
end_obj(i == (n_segments - 1));
|
||||
}
|
||||
|
||||
end_arr(true);
|
||||
end_obj(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
@ -333,8 +540,13 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
// TODO: become parameter
|
||||
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
static const char * font = params.font_path.c_str();
|
||||
|
||||
std::ifstream fin(font);
|
||||
if (!fin.is_open()) {
|
||||
fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
|
||||
return false;
|
||||
}
|
||||
|
||||
fout << "#!/bin/bash" << "\n";
|
||||
fout << "\n";
|
||||
@ -377,7 +589,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
txt_ul = "\\ \\ ";
|
||||
|
||||
{
|
||||
int ncnt = 0;
|
||||
for (int k = 0; k < n; ++k) {
|
||||
const auto & token2 = tokens[k];
|
||||
|
||||
@ -401,8 +612,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
txt_ul += "\\ ";
|
||||
}
|
||||
}
|
||||
|
||||
ncnt += txt.size();
|
||||
}
|
||||
|
||||
::replace_all(txt_bg, "'", "\u2019");
|
||||
@ -440,6 +649,39 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_lrc(struct whisper_context * ctx, const char * fname) {
|
||||
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
fout << "[by:whisper.cpp]\n";
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t = whisper_full_get_segment_t0(ctx, i);
|
||||
|
||||
int64_t msec = t * 10;
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[16];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
|
||||
std::string timestamp_lrc = std::string(buf);
|
||||
|
||||
fout << '[' << timestamp_lrc << ']' << text << "\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -461,115 +703,23 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (params.prompt.size() > 0) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||
fprintf(stderr, "initial tokens: [ ");
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||
}
|
||||
fprintf(stderr, "]\n");
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true)
|
||||
{
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return 5;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
||||
return 8;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
||||
return 9;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
// print system information
|
||||
@ -603,6 +753,8 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
@ -617,11 +769,18 @@ int main(int argc, char ** argv) {
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
@ -637,7 +796,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
@ -656,27 +815,45 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// output to text file
|
||||
if (params.output_txt) {
|
||||
const auto fname_txt = fname_inp + ".txt";
|
||||
const auto fname_txt = fname_out + ".txt";
|
||||
output_txt(ctx, fname_txt.c_str());
|
||||
}
|
||||
|
||||
// output to VTT file
|
||||
if (params.output_vtt) {
|
||||
const auto fname_vtt = fname_inp + ".vtt";
|
||||
const auto fname_vtt = fname_out + ".vtt";
|
||||
output_vtt(ctx, fname_vtt.c_str());
|
||||
}
|
||||
|
||||
// output to SRT file
|
||||
if (params.output_srt) {
|
||||
const auto fname_srt = fname_inp + ".srt";
|
||||
const auto fname_srt = fname_out + ".srt";
|
||||
output_srt(ctx, fname_srt.c_str(), params);
|
||||
}
|
||||
|
||||
// output to WTS file
|
||||
if (params.output_wts) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
const auto fname_wts = fname_out + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_out + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
// output to JSON file
|
||||
if (params.output_jsn) {
|
||||
const auto fname_jsn = fname_out + ".json";
|
||||
output_json(ctx, fname_jsn.c_str(), params);
|
||||
}
|
||||
|
||||
// output to LRC file
|
||||
if (params.output_lrc) {
|
||||
const auto fname_lrc = fname_out + ".lrc";
|
||||
output_lrc(ctx, fname_lrc.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
6
examples/quantize/CMakeLists.txt
Normal file
6
examples/quantize/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
set(TARGET quantize)
|
||||
add_executable(${TARGET} quantize.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
3
examples/quantize/README.md
Normal file
3
examples/quantize/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# quantize
|
||||
|
||||
Tool for integer quantization of Whisper `ggml` model files
|
215
examples/quantize/quantize.cpp
Normal file
215
examples/quantize/quantize.cpp
Normal file
@ -0,0 +1,215 @@
|
||||
#include "ggml.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
// default hparams (Whisper tiny)
|
||||
struct whisper_hparams {
|
||||
int32_t n_vocab = 51864;
|
||||
int32_t n_audio_ctx = 1500;
|
||||
int32_t n_audio_state = 384;
|
||||
int32_t n_audio_head = 6;
|
||||
int32_t n_audio_layer = 4;
|
||||
int32_t n_text_ctx = 448;
|
||||
int32_t n_text_state = 384;
|
||||
int32_t n_text_head = 6;
|
||||
int32_t n_text_layer = 4;
|
||||
int32_t n_mels = 80;
|
||||
int32_t f16 = 1;
|
||||
};
|
||||
|
||||
struct whisper_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
// quantize a model
|
||||
bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
|
||||
gpt_vocab vocab;
|
||||
|
||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||
|
||||
auto finp = std::ifstream(fname_inp, std::ios::binary);
|
||||
if (!finp) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto fout = std::ofstream(fname_out, std::ios::binary);
|
||||
if (!fout) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
}
|
||||
|
||||
whisper_hparams hparams;
|
||||
|
||||
// load hparams
|
||||
{
|
||||
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
finp.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
|
||||
finp.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
|
||||
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
|
||||
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
|
||||
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
finp.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
|
||||
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
||||
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
||||
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
||||
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
||||
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
|
||||
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
|
||||
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
||||
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
||||
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
||||
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
|
||||
|
||||
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
|
||||
fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
|
||||
fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
|
||||
fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
|
||||
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
fout.write((char *) &ftype, sizeof(hparams.f16));
|
||||
}
|
||||
|
||||
// load mel filters
|
||||
{
|
||||
whisper_filters filters;
|
||||
|
||||
finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel));
|
||||
fout.write((char *) &filters.n_mel, sizeof(filters.n_mel));
|
||||
finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft));
|
||||
fout.write((char *) &filters.n_fft, sizeof(filters.n_fft));
|
||||
|
||||
filters.data.resize(filters.n_mel * filters.n_fft);
|
||||
finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
||||
fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
finp.read ((char *) &n_vocab, sizeof(n_vocab));
|
||||
fout.write((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
//if (n_vocab != hparams.n_vocab) {
|
||||
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
// __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
|
||||
// return false;
|
||||
//}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
finp.read ((char *) &len, sizeof(len));
|
||||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
finp.read ((char *) word.data(), len);
|
||||
fout.write((char *) word.data(), len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// regexes of tensor names to not be quantized
|
||||
const std::vector<std::string> to_skip = {
|
||||
//"encoder.*",
|
||||
"encoder.conv1.bias",
|
||||
"encoder.conv2.bias",
|
||||
"encoder.positional_embedding",
|
||||
"decoder.positional_embedding",
|
||||
};
|
||||
|
||||
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
|
||||
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
finp.close();
|
||||
fout.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc != 4) {
|
||||
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// needed to initialize f16 tables
|
||||
{
|
||||
struct ggml_init_params params = { 0, NULL, false };
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
const std::string fname_inp = argv[1];
|
||||
const std::string fname_out = argv[2];
|
||||
|
||||
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
|
||||
|
||||
const int64_t t_main_start_us = ggml_time_us();
|
||||
|
||||
int64_t t_quantize_us = 0;
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
|
||||
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
t_quantize_us = ggml_time_us() - t_start_us;
|
||||
}
|
||||
|
||||
// report timing
|
||||
{
|
||||
const int64_t t_main_end_us = ggml_time_us();
|
||||
|
||||
printf("\n");
|
||||
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
|
||||
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -49,6 +49,9 @@ void stream_main(size_t index) {
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
printf("stream: using %d threads\n", wparams.n_threads);
|
||||
@ -129,7 +132,7 @@ EMSCRIPTEN_BINDINGS(stream) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and start speaking
|
||||
@ -45,6 +54,10 @@
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
@ -162,11 +175,17 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -177,6 +196,10 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -188,6 +211,10 @@
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -3,18 +3,16 @@
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
@ -45,13 +43,14 @@ struct whisper_params {
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool no_fallback = false;
|
||||
bool print_special = false;
|
||||
bool no_context = true;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out = "";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -75,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
@ -90,329 +90,32 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -420,20 +123,21 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms);
|
||||
params.length_ms = std::max(params.length_ms, params.step_ms);
|
||||
|
||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
|
||||
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context = use_vad;
|
||||
params.max_tokens = 0;
|
||||
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context |= use_vad;
|
||||
params.max_tokens = 0;
|
||||
|
||||
// init audio
|
||||
|
||||
@ -447,16 +151,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old;
|
||||
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
@ -483,7 +187,7 @@ int main(int argc, char ** argv) {
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
if (!use_vad) {
|
||||
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
|
||||
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
|
||||
} else {
|
||||
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
|
||||
}
|
||||
@ -513,23 +217,7 @@ int main(int argc, char ** argv) {
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
@ -552,7 +240,7 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
SDL_Delay(1);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
|
||||
const int n_samples_new = pcmf32_new.size();
|
||||
@ -583,7 +271,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.get(2000, pcmf32_new);
|
||||
|
||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
audio.get(params.length_ms, pcmf32);
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
@ -603,7 +291,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = !use_vad;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
@ -612,6 +299,10 @@ int main(int argc, char ** argv) {
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
// disable temperature fallback
|
||||
//wparams.temperature_inc = -1.0f;
|
||||
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
|
||||
|
||||
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
|
||||
|
||||
@ -692,6 +383,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
|
1
examples/talk-llama/.gitignore
vendored
Normal file
1
examples/talk-llama/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
audio.mp3
|
16
examples/talk-llama/CMakeLists.txt
Normal file
16
examples/talk-llama/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
if (WHISPER_SDL2)
|
||||
# talk-llama
|
||||
set(TARGET talk-llama)
|
||||
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
endif ()
|
36
examples/talk-llama/README.md
Normal file
36
examples/talk-llama/README.md
Normal file
@ -0,0 +1,36 @@
|
||||
# talk-llama
|
||||
|
||||
Talk with an LLaMA AI in your terminal
|
||||
|
||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
|
||||
|
||||
## Building
|
||||
|
||||
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
# Build the "talk-llama" executable
|
||||
make talk-llama
|
||||
|
||||
# Run it
|
||||
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
||||
```
|
||||
|
||||
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
|
||||
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
|
||||
|
||||
## TTS
|
||||
|
||||
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
|
||||
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
|
||||
By default, it is configured to use MacOS's `say`, but you can use whatever you wish.
|
||||
|
||||
## Discussion
|
||||
|
||||
If you have any feedback, please let "us" know in the following discussion: https://github.com/ggerganov/whisper.cpp/discussions/672?converting=1
|
23
examples/talk-llama/eleven-labs.py
Normal file
23
examples/talk-llama/eleven-labs.py
Normal file
@ -0,0 +1,23 @@
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
api_key = "" #Write your https://beta.elevenlabs.io api key here
|
||||
if not api_key:
|
||||
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk-llama/eleven-labs.py")
|
||||
sys.exit()
|
||||
|
||||
if importlib.util.find_spec("elevenlabs") is None:
|
||||
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
|
||||
sys.exit()
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
eleven = ElevenLabs(api_key)
|
||||
|
||||
# Get a Voice object, by name or UUID
|
||||
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
|
||||
|
||||
# Generate the TTS
|
||||
audio = voice.generate(str(sys.argv[2:]))
|
||||
|
||||
# Save the TTS to a file
|
||||
audio.save("audio")
|
433
examples/talk-llama/llama-util.h
Normal file
433
examples/talk-llama/llama-util.h
Normal file
@ -0,0 +1,433 @@
|
||||
// Internal header to be included only by llama.cpp.
|
||||
// Contains wrappers around OS interfaces.
|
||||
|
||||
#ifndef LLAMA_UTIL_H
|
||||
#define LLAMA_UTIL_H
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <climits>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <io.h>
|
||||
#include <stdio.h> // for _fseeki64
|
||||
#endif
|
||||
|
||||
#define LLAMA_ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||
abort(); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
__attribute__((format(gnu_printf, 1, 2)))
|
||||
#else
|
||||
__attribute__((format(printf, 1, 2)))
|
||||
#endif
|
||||
#endif
|
||||
static std::string format(const char * fmt, ...) {
|
||||
va_list ap, ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
LLAMA_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
struct llama_file {
|
||||
// use FILE * so we don't have to re-open the file to mmap
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
|
||||
llama_file(const char * fname, const char * mode) {
|
||||
fp = std::fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw format("failed to open %s: %s", fname, std::strerror(errno));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) {
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret == 0); // same
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t size) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, size, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw format("read error: %s", strerror(errno));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::string("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
|
||||
std::uint32_t read_u32() {
|
||||
std::uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string read_string(std::uint32_t len) {
|
||||
std::vector<char> chars(len);
|
||||
read_raw(chars.data(), len);
|
||||
return std::string(chars.data(), len);
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t size) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
size_t ret = std::fwrite(ptr, size, 1, fp);
|
||||
if (ret != 1) {
|
||||
throw format("write error: %s", strerror(errno));
|
||||
}
|
||||
}
|
||||
|
||||
void write_u32(std::uint32_t val) {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(_WIN32)
|
||||
static std::string llama_format_win_err(DWORD err) {
|
||||
LPSTR buf;
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
|
||||
if (!size) {
|
||||
return "FormatMessageA failed";
|
||||
}
|
||||
std::string ret(buf, size);
|
||||
LocalFree(buf);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct llama_mmap {
|
||||
void * addr;
|
||||
size_t size;
|
||||
|
||||
llama_mmap(const llama_mmap &) = delete;
|
||||
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
int fd = fileno(file->fp);
|
||||
int flags = MAP_SHARED;
|
||||
#ifdef __linux__
|
||||
flags |= MAP_POPULATE;
|
||||
#endif
|
||||
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
|
||||
if (addr == MAP_FAILED) {
|
||||
throw format("mmap failed: %s", strerror(errno));
|
||||
}
|
||||
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
if (madvise(addr, file->size, MADV_WILLNEED)) {
|
||||
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
munmap(addr, size);
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
DWORD error = GetLastError();
|
||||
|
||||
if (hMapping == NULL) {
|
||||
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
|
||||
}
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
error = GetLastError();
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
|
||||
}
|
||||
|
||||
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
WIN32_MEMORY_RANGE_ENTRY range;
|
||||
range.VirtualAddress = addr;
|
||||
range.NumberOfBytes = (SIZE_T)size;
|
||||
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
||||
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
|
||||
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
llama_mmap(struct llama_file *) {
|
||||
throw std::string("mmap not supported");
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Represents some region of memory being locked using mlock or VirtualLock;
|
||||
// will automatically unlock on destruction.
|
||||
struct llama_mlock {
|
||||
void * addr = NULL;
|
||||
size_t size = 0;
|
||||
bool failed_already = false;
|
||||
|
||||
llama_mlock() {}
|
||||
llama_mlock(const llama_mlock &) = delete;
|
||||
|
||||
~llama_mlock() {
|
||||
if (size) {
|
||||
raw_unlock(addr, size);
|
||||
}
|
||||
}
|
||||
|
||||
void init(void * addr) {
|
||||
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
|
||||
this->addr = addr;
|
||||
}
|
||||
|
||||
void grow_to(size_t target_size) {
|
||||
LLAMA_ASSERT(addr);
|
||||
if (failed_already) {
|
||||
return;
|
||||
}
|
||||
size_t granularity = lock_granularity();
|
||||
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
||||
if (target_size > size) {
|
||||
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
||||
size = target_size;
|
||||
} else {
|
||||
failed_already = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _POSIX_MEMLOCK_RANGE
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
return (size_t) sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
||||
#else
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
||||
#endif
|
||||
|
||||
bool raw_lock(const void * addr, size_t size) {
|
||||
if (!mlock(addr, size)) {
|
||||
return true;
|
||||
} else {
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
|
||||
// Check if the resource limit is fine after all
|
||||
struct rlimit lock_limit;
|
||||
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
|
||||
suggest = false;
|
||||
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
|
||||
suggest = false;
|
||||
|
||||
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
||||
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#undef MLOCK_SUGGESTION
|
||||
|
||||
void raw_unlock(void * addr, size_t size) {
|
||||
if (munlock(addr, size)) {
|
||||
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
SYSTEM_INFO si;
|
||||
GetSystemInfo(&si);
|
||||
return (size_t) si.dwPageSize;
|
||||
}
|
||||
|
||||
bool raw_lock(void * addr, size_t size) {
|
||||
for (int tries = 1; ; tries++) {
|
||||
if (VirtualLock(addr, size)) {
|
||||
return true;
|
||||
}
|
||||
if (tries == 2) {
|
||||
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
||||
size, this->size, llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// It failed but this was only the first try; increase the working
|
||||
// set size and try again.
|
||||
SIZE_T min_ws_size, max_ws_size;
|
||||
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
||||
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
// Per MSDN: "The maximum number of pages that a process can lock
|
||||
// is equal to the number of pages in its minimum working set minus
|
||||
// a small overhead."
|
||||
// Hopefully a megabyte is enough overhead:
|
||||
size_t increment = size + 1048576;
|
||||
// The minimum must be <= the maximum, so we need to increase both:
|
||||
min_ws_size += increment;
|
||||
max_ws_size += increment;
|
||||
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
||||
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void raw_unlock(void * addr, size_t size) {
|
||||
if (!VirtualUnlock(addr, size)) {
|
||||
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
void raw_lock(const void * addr, size_t size) {
|
||||
fprintf(stderr, "warning: mlock not supported on this system\n");
|
||||
}
|
||||
|
||||
void raw_unlock(const void * addr, size_t size) {}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
|
||||
struct llama_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
size_t size = 0;
|
||||
|
||||
void resize(size_t size) {
|
||||
delete[] addr;
|
||||
addr = new uint8_t[size];
|
||||
this->size = size;
|
||||
}
|
||||
|
||||
~llama_buffer() {
|
||||
delete[] addr;
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
struct llama_ctx_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
size_t size = 0;
|
||||
|
||||
void resize(size_t size) {
|
||||
if (addr) {
|
||||
ggml_cuda_host_free(addr);
|
||||
}
|
||||
addr = (uint8_t *) ggml_cuda_host_malloc(size);
|
||||
this->size = size;
|
||||
}
|
||||
|
||||
~llama_ctx_buffer() {
|
||||
if (addr) {
|
||||
ggml_cuda_host_free(addr);
|
||||
}
|
||||
}
|
||||
};
|
||||
#else
|
||||
typedef llama_buffer llama_ctx_buffer;
|
||||
#endif
|
||||
|
||||
#endif
|
2750
examples/talk-llama/llama.cpp
Normal file
2750
examples/talk-llama/llama.cpp
Normal file
File diff suppressed because it is too large
Load Diff
257
examples/talk-llama/llama.h
Normal file
257
examples/talk-llama/llama.h
Normal file
@ -0,0 +1,257 @@
|
||||
#ifndef LLAMA_H
|
||||
#define LLAMA_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef LLAMA_BUILD
|
||||
# define LLAMA_API __declspec(dllexport)
|
||||
# else
|
||||
# define LLAMA_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define LLAMA_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define LLAMA_API
|
||||
#endif
|
||||
|
||||
#define LLAMA_FILE_VERSION 1
|
||||
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
|
||||
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//
|
||||
// C interface
|
||||
//
|
||||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
struct llama_context;
|
||||
|
||||
typedef int llama_token;
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
float p; // probability of the token
|
||||
} llama_token_data;
|
||||
|
||||
typedef struct llama_token_data_array {
|
||||
llama_token_data * data;
|
||||
size_t size;
|
||||
bool sorted;
|
||||
} llama_token_data_array;
|
||||
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
struct llama_context_params {
|
||||
int n_ctx; // text context
|
||||
int n_parts; // -1 for default
|
||||
int seed; // RNG seed, 0 for random
|
||||
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
llama_progress_callback progress_callback;
|
||||
// context pointer passed to the progress callback
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
};
|
||||
|
||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||
|
||||
LLAMA_API bool llama_mmap_supported();
|
||||
LLAMA_API bool llama_mlock_supported();
|
||||
|
||||
// Various functions for loading a ggml llama model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
LLAMA_API struct llama_context * llama_init_from_file(
|
||||
const char * path_model,
|
||||
struct llama_context_params params);
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
// TODO: not great API - very likely to change
|
||||
// Returns 0 on success
|
||||
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
||||
LLAMA_API int llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
const char * fname_out,
|
||||
enum llama_ftype ftype,
|
||||
int nthread);
|
||||
|
||||
// Apply a LoRA adapter to a loaded model
|
||||
// path_base_model is the path to a higher quality model to use as a base for
|
||||
// the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_apply_lora_from_file(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora,
|
||||
const char * path_base_model,
|
||||
int n_threads);
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
|
||||
|
||||
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
||||
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
|
||||
|
||||
// Copies the state to the specified destination address.
|
||||
// Destination needs to have allocated enough memory.
|
||||
// Returns the number of bytes copied
|
||||
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
|
||||
|
||||
// Set the state reading from the specified address
|
||||
// Returns the number of bytes read
|
||||
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
|
||||
|
||||
// Save/load session file
|
||||
LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
|
||||
LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
|
||||
|
||||
// Run the llama inference to obtain the logits and probabilities for the next token.
|
||||
// tokens + n_tokens is the provided batch of new tokens to process
|
||||
// n_past is the number of tokens to use from previous eval calls
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
const llama_token * tokens,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
// TODO: not sure if correct
|
||||
LLAMA_API int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
||||
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_embd (struct llama_context * ctx);
|
||||
|
||||
// Token logits obtained from the last call to llama_eval()
|
||||
// The logits for the last token are stored in the last row
|
||||
// Can be mutated in order to change the probabilities of the next token
|
||||
// Rows: n_tokens
|
||||
// Cols: n_vocab
|
||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the input
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||
|
||||
// Special tokens
|
||||
LLAMA_API llama_token llama_token_bos();
|
||||
LLAMA_API llama_token llama_token_eos();
|
||||
LLAMA_API llama_token llama_token_nl();
|
||||
|
||||
// Sampling functions
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
|
||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
|
||||
|
||||
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
|
||||
/// @details Selects the token with the highest probability.
|
||||
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
||||
#ifdef LLAMA_API_INTERNAL
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
struct ggml_tensor;
|
||||
|
||||
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||
|
||||
#endif
|
||||
|
||||
#endif // LLAMA_H
|
23
examples/talk-llama/prompts/talk-alpaca.txt
Normal file
23
examples/talk-llama/prompts/talk-alpaca.txt
Normal file
@ -0,0 +1,23 @@
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
|
||||
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
|
||||
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
|
||||
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
|
||||
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||
{1} responds with short and concise answers.
|
||||
|
||||
### Response:
|
||||
|
||||
{0}{4} Hello, {1}!
|
||||
{1}{4} Hello {0}! How may I help you today?
|
||||
{0}{4} What time is it?
|
||||
{1}{4} It is {2} o'clock.
|
||||
{0}{4} What year is it?
|
||||
{1}{4} We are in {3}.
|
||||
{0}{4} What is a cat?
|
||||
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||
{0}{4} Name a color.
|
||||
{1}{4} Blue
|
||||
{0}{4}
|
21
examples/talk-llama/speak.sh
Executable file
21
examples/talk-llama/speak.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Usage:
|
||||
# speak.sh <voice_id> <text-to-speak>
|
||||
|
||||
# espeak
|
||||
# Mac OS: brew install espeak
|
||||
# Linux: apt-get install espeak
|
||||
#
|
||||
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# for Mac
|
||||
say "$2"
|
||||
|
||||
# Eleven Labs
|
||||
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk-llama/eleven-labs.py
|
||||
#
|
||||
#wd=$(dirname $0)
|
||||
#script=$wd/eleven-labs.py
|
||||
#python3 $script $1 "$2" >/dev/null 2>&1
|
||||
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
|
576
examples/talk-llama/talk-llama.cpp
Normal file
576
examples/talk-llama/talk-llama.cpp
Normal file
@ -0,0 +1,576 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
|
||||
std::vector<llama_token> res(text.size() + (int)add_bos);
|
||||
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
|
||||
assert(n >= 0);
|
||||
res.resize(n);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t voice_ms = 10000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
int32_t n_parts_llama = -1;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
bool verbose_prompt = false;
|
||||
|
||||
std::string person = "Georgi";
|
||||
std::string language = "en";
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_llama = "models/ggml-llama-7B.bin";
|
||||
std::string speak = "./examples/talk-llama/speak.sh";
|
||||
std::string prompt = "";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
|
||||
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
||||
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
|
||||
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
|
||||
else if (arg == "--prompt-file") {
|
||||
std::ifstream file(argv[++i]);
|
||||
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
|
||||
if (params.prompt.back() == '\n') {
|
||||
params.prompt.pop_back();
|
||||
}
|
||||
}
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
||||
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
|
||||
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
|
||||
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
|
||||
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
|
||||
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string prompt_text,
|
||||
float & prob,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
|
||||
|
||||
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
|
||||
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
|
||||
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
|
||||
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||
{1} responds with short and concise answers.
|
||||
|
||||
{0}{4} Hello, {1}!
|
||||
{1}{4} Hello {0}! How may I help you today?
|
||||
{0}{4} What time is it?
|
||||
{1}{4} It is {2} o'clock.
|
||||
{0}{4} What year is it?
|
||||
{1}{4} We are in {3}.
|
||||
{0}{4} What is a cat?
|
||||
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||
{0}{4} Name a color.
|
||||
{1}{4} Blue
|
||||
{0}{4})";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
|
||||
|
||||
// llama init
|
||||
|
||||
auto lparams = llama_context_default_params();
|
||||
|
||||
// tune these to your liking
|
||||
lparams.n_ctx = 2048;
|
||||
lparams.seed = 1;
|
||||
lparams.f16_kv = true;
|
||||
lparams.n_parts = params.n_parts_llama;
|
||||
|
||||
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
if (!whisper_is_multilingual(ctx_wsp)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
|
||||
const std::string chat_symb = ":";
|
||||
const std::string bot_name = "LLaMA";
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
|
||||
|
||||
// construct the initial prompt for LLaMA inference
|
||||
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
|
||||
|
||||
// need to have leading ' '
|
||||
prompt_llama.insert(0, 1, ' ');
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
|
||||
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
|
||||
|
||||
{
|
||||
// get time string
|
||||
std::string time_str;
|
||||
{
|
||||
time_t t = time(0);
|
||||
struct tm * now = localtime(&t);
|
||||
char buf[128];
|
||||
strftime(buf, sizeof(buf), "%H:%M", now);
|
||||
time_str = buf;
|
||||
}
|
||||
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
|
||||
}
|
||||
|
||||
{
|
||||
// get year string
|
||||
std::string year_str;
|
||||
{
|
||||
time_t t = time(0);
|
||||
struct tm * now = localtime(&t);
|
||||
char buf[128];
|
||||
strftime(buf, sizeof(buf), "%Y", now);
|
||||
year_str = buf;
|
||||
}
|
||||
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
|
||||
}
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
||||
|
||||
// evaluate the initial prompt
|
||||
|
||||
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
||||
|
||||
printf("\n");
|
||||
printf("%s : initializing - please wait ...\n", __func__);
|
||||
|
||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s", prompt_llama.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
printf("%s : done! start speaking in the microphone\n", __func__);
|
||||
printf("\n");
|
||||
printf("%s%s", params.person.c_str(), chat_symb.c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// clear audio buffer
|
||||
audio.clear();
|
||||
|
||||
// text inference variables
|
||||
const int voice_id = 2;
|
||||
const int n_keep = embd_inp.size();
|
||||
const int n_ctx = llama_n_ctx(ctx_llama);
|
||||
|
||||
int n_past = n_keep;
|
||||
int n_prev = 64; // TODO arg
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
// reverse prompts for detecting when it's time to stop speaking
|
||||
std::vector<std::string> antiprompts = {
|
||||
params.person + chat_symb,
|
||||
};
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\(.*?\\)");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
|
||||
|
||||
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
|
||||
|
||||
if (text_heard.empty() || tokens.empty() || force_speak) {
|
||||
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
|
||||
audio.clear();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
text_heard.insert(0, 1, ' ');
|
||||
text_heard += "\n" + bot_name + chat_symb;
|
||||
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
|
||||
fflush(stdout);
|
||||
|
||||
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
||||
|
||||
// text inference
|
||||
bool done = false;
|
||||
std::string text_to_speak;
|
||||
while (true) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (n_past + (int) embd.size() > n_ctx) {
|
||||
n_past = n_keep;
|
||||
|
||||
// insert n_left/2 tokens at the start of embd from last_n_tokens
|
||||
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
|
||||
|
||||
//printf("\n---\n");
|
||||
//printf("resetting: '");
|
||||
//for (int i = 0; i < (int) embd.size(); i++) {
|
||||
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
|
||||
//}
|
||||
//printf("'\n");
|
||||
//printf("\n---\n");
|
||||
}
|
||||
|
||||
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
||||
n_past += embd.size();
|
||||
embd.clear();
|
||||
|
||||
if (done) break;
|
||||
|
||||
{
|
||||
// out of user input, sample next token
|
||||
const float top_k = 5;
|
||||
const float top_p = 0.80f;
|
||||
const float temp = 0.30f;
|
||||
const float repeat_penalty = 1.1764f;
|
||||
|
||||
const int repeat_last_n = 256;
|
||||
|
||||
llama_token id = 0;
|
||||
|
||||
{
|
||||
auto logits = llama_get_logits(ctx_llama);
|
||||
auto n_vocab = llama_n_vocab(ctx_llama);
|
||||
|
||||
logits[llama_token_eos()] = 0;
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// apply repeat penalty
|
||||
const float nl_logit = logits[llama_token_nl()];
|
||||
|
||||
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
|
||||
embd_inp.data() + std::max(0, n_past - repeat_last_n),
|
||||
repeat_last_n, repeat_penalty);
|
||||
|
||||
logits[llama_token_nl()] = nl_logit;
|
||||
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(ctx_llama, &candidates_p, top_k);
|
||||
llama_sample_top_p(ctx_llama, &candidates_p, top_p);
|
||||
llama_sample_temperature(ctx_llama, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx_llama, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
||||
if (id != llama_token_eos()) {
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
|
||||
text_to_speak += llama_token_to_str(ctx_llama, id);
|
||||
|
||||
printf("%s", llama_token_to_str(ctx_llama, id));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::string last_output;
|
||||
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
|
||||
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
|
||||
}
|
||||
last_output += llama_token_to_str(ctx_llama, embd[0]);
|
||||
|
||||
for (std::string & antiprompt : antiprompts) {
|
||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||
done = true;
|
||||
text_to_speak = ::replace(text_to_speak, antiprompt, "");
|
||||
fflush(stdout);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
|
||||
audio.clear();
|
||||
|
||||
++n_iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx_wsp);
|
||||
whisper_free(ctx_wsp);
|
||||
|
||||
llama_print_timings(ctx_llama);
|
||||
llama_free(ctx_llama);
|
||||
|
||||
return 0;
|
||||
}
|
@ -9,8 +9,11 @@ add_executable(${TARGET}
|
||||
gpt-2.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
common
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
@ -31,8 +34,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1600MB \
|
||||
-s TOTAL_MEMORY=1600MB \
|
||||
-s INITIAL_MEMORY=1800MB \
|
||||
-s TOTAL_MEMORY=1800MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
|
||||
- Latest Chrome or Firefox browser (Safari is not supported)
|
||||
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
|
||||
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
|
||||
- The web-page uses about 1.6GB of RAM
|
||||
- The web-page uses about 1.8GB of RAM
|
||||
|
||||
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
|
||||
Also, the prompting strategy can likely be improved to achieve better results.
|
||||
|
@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -1,4 +1,6 @@
|
||||
#include "ggml.h"
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
@ -14,150 +16,6 @@
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.push_back(std::make_pair(logits[i], i));
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
@ -165,7 +23,7 @@ struct gpt2_hparams {
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
int32_t ftype = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
@ -187,7 +45,7 @@ struct gpt2_layer {
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_w;
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
@ -198,8 +56,9 @@ struct gpt2_model {
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * lm_head; // language model head
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
@ -241,14 +100,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
printf("%s: ftype = %d\n", __func__, hparams.ftype);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
@ -275,9 +134,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
|
||||
if (wtype == GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
|
||||
__func__, fname.c_str(), model.hparams.ftype);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
@ -291,32 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
@ -325,9 +190,11 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
@ -350,36 +217,38 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/lm_head"] = model.lm_head;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
@ -397,7 +266,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
@ -425,14 +294,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
bool has_lm_head = false;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
int32_t ttype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
@ -461,13 +332,18 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
// for debugging
|
||||
if (0) {
|
||||
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
@ -475,7 +351,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
// GPT-2 models share the WTE tensor as the LM head
|
||||
if (name == "model/wte" && has_lm_head == false) {
|
||||
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (name == "model/lm_head") {
|
||||
has_lm_head = true;
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
@ -493,7 +377,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
// - embd_w: the predicted logits for the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
@ -512,12 +396,12 @@ bool gpt2_eval(
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 640u*1024*1024;
|
||||
static size_t buf_size = 512u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
@ -528,13 +412,14 @@ bool gpt2_eval(
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ buf,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
struct ggml_cgraph gf = {};
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
@ -578,7 +463,7 @@ bool gpt2_eval(
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
model.layers[il].c_attn_attn_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -654,11 +539,13 @@ bool gpt2_eval(
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
@ -685,7 +572,7 @@ bool gpt2_eval(
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
model.layers[il].c_attn_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -722,7 +609,7 @@ bool gpt2_eval(
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
model.layers[il].c_mlp_fc_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -742,7 +629,7 @@ bool gpt2_eval(
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
model.layers[il].c_mlp_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -769,12 +656,12 @@ bool gpt2_eval(
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, 50257] - model.lm_head
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
//inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
@ -788,7 +675,7 @@ bool gpt2_eval(
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
// return result just for the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
@ -825,7 +712,7 @@ Me too.
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
int32_t top_k = 5;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
@ -833,14 +720,15 @@ Me too.
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(NULL));
|
||||
ctx->rng = std::mt19937(time(nullptr));
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -884,9 +772,9 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
|
||||
std::string result;
|
||||
|
||||
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
|
||||
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (!embd.empty()) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
@ -913,10 +801,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256 ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "." ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "!" ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "?") {
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user