mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-24 17:15:19 +00:00
Compare commits
164 Commits
Author | SHA1 | Date | |
---|---|---|---|
18d5ff8695 | |||
14bee39b29 | |||
d458fcbc15 | |||
919e58b96a | |||
05bef0f0e9 | |||
5974c8facd | |||
0bcb64b184 | |||
0bf680fea2 | |||
b806420873 | |||
be5911a9f3 | |||
d375d73b2e | |||
7765770f89 | |||
872a85ae94 | |||
9c61f5f585 | |||
c94c469592 | |||
feac80dd3f | |||
fa8dbdc888 | |||
4a7d49af95 | |||
794b162a46 | |||
5fd1bdd7fc | |||
0ccd6746c9 | |||
d9b550c0a1 | |||
e9b091c92a | |||
1f30b99208 | |||
05c3ea3bc8 | |||
6108d3cc58 | |||
bab97c83d0 | |||
3eaeb030ff | |||
acec73ab6e | |||
5cc17418c7 | |||
3efb81dec6 | |||
94a7cd2a07 | |||
3e82ff4747 | |||
b5bd2f43c5 | |||
94aa56f19e | |||
4d89ee2e59 | |||
70567eff23 | |||
02ec83c5d5 | |||
2bd4b8d577 | |||
eecf2c3d41 | |||
c23588cc4b | |||
5108b30e6d | |||
f19e23fbd1 | |||
ea1f8a50d4 | |||
3dead611bb | |||
355da83690 | |||
3e5c49e59a | |||
5e47e223bd | |||
794ff3074a | |||
7e2afa4384 | |||
1c5edc3cb3 | |||
34b772727d | |||
2c856fb9e5 | |||
7727a40dc9 | |||
b5639ed313 | |||
2c4ac2627d | |||
674a8e579b | |||
001083a769 | |||
62b51c3070 | |||
61128870b8 | |||
78548dc03f | |||
66110dafcc | |||
b73a4638ac | |||
5f16420333 | |||
ccb47e7e10 | |||
677ad754a0 | |||
514cd04452 | |||
6704a81255 | |||
463e46338c | |||
2f889132c6 | |||
ebef1e8620 | |||
114df388fe | |||
ea36831459 | |||
69b8503935 | |||
0a2d1210bc | |||
859ffc994e | |||
5e6e2187a3 | |||
a7f1f33715 | |||
86ecfc6333 | |||
18e6fb0287 | |||
0f759f125d | |||
eefed45e37 | |||
aac1710afb | |||
21c1e6afc5 | |||
a47e812a54 | |||
42c6855103 | |||
0be9cd3497 | |||
e5c197d8aa | |||
7cd1d3bc34 | |||
82637b8e9f | |||
4a0deb8b1e | |||
8e361d90d7 | |||
fc49c44426 | |||
aec01bb337 | |||
21165580a1 | |||
1d749919e3 | |||
d4fa0d92ad | |||
a5e60c019d | |||
8fcd1a3b32 | |||
992aa2cd1b | |||
4aa3bcf8a4 | |||
1beff6f66d | |||
09e9068007 | |||
fa9d43181f | |||
bb6b54a03d | |||
b597c5a779 | |||
a3fb6c507f | |||
59fdcd19c8 | |||
478289a4b3 | |||
5e94129cb2 | |||
72af0f5697 | |||
af005d573f | |||
ad1389003d | |||
f420de1322 | |||
d176160f6f | |||
ca21f7ab16 | |||
373043cabe | |||
fb4d0d470f | |||
0d229163bb | |||
f254e78737 | |||
a94897bcde | |||
2407ae8ef0 | |||
b623ca43b1 | |||
69e6e4644a | |||
09d7d2b68e | |||
0336161b7d | |||
459753342d | |||
9764782bd9 | |||
3b010f9bed | |||
113fcec513 | |||
cfc06bf8df | |||
2bfe0ebc0f | |||
4dd7119deb | |||
ab1916fc59 | |||
a1c1583cc7 | |||
d012b5c7e4 | |||
b2083c5d02 | |||
f3ee4a9673 | |||
c306a7fd89 | |||
b2fc4c7010 | |||
291980369c | |||
86ef64a855 | |||
3b1960520a | |||
2bee2650c6 | |||
beb9512be3 | |||
47737b2e82 | |||
b992f3709e | |||
60337f5306 | |||
02c7516c57 | |||
411ea9b833 | |||
11f61cecd6 | |||
b5ddb16ec7 | |||
ae16c21e9c | |||
2c3f50a021 | |||
9a65269a20 | |||
78f166174f | |||
21c569ba4a | |||
1a91c19af9 | |||
f583e2d2f5 | |||
206fc93396 | |||
a6cf6f4c4a | |||
472a473fd1 | |||
9ba66c2fad | |||
1ccb8a46a5 |
@ -1,13 +1,18 @@
|
||||
name: Bindings Tests
|
||||
name: Bindings Tests (Go)
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '^1.19'
|
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Bindings Tests (Ruby)
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/ruby/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- bindings/ruby/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: ruby/setup-ruby@v1
|
||||
with:
|
||||
ruby-version: '3.0'
|
||||
- uses: actions/checkout@v1
|
||||
- run: |
|
||||
cd bindings/ruby/ext
|
||||
ruby extconf.rb && make
|
522
.github/workflows/build.yml
vendored
522
.github/workflows/build.yml
vendored
@ -1,267 +1,363 @@
|
||||
name: CI
|
||||
on: [push]
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install sdl2
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install sdl2
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
make stream
|
||||
|
||||
ubuntu-latest-gcc:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-gcc:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
ubuntu-latest-clang:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-clang:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
- name: Configure
|
||||
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
ubuntu-latest-gcc-sanitized:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-latest-gcc-sanitized:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, THREAD, UNDEFINED]
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, THREAD, UNDEFINED]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
|
||||
- name: Configure
|
||||
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
- name: Configure
|
||||
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
ctest -L gh --output-on-failure
|
||||
|
||||
windows:
|
||||
runs-on: windows-latest
|
||||
windows:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch OpenBLAS
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
|
||||
7z x blas.zip -oblas -y
|
||||
copy blas/include/cblas.h .
|
||||
copy blas/include/openblas_config.h .
|
||||
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
|
||||
- name: Fetch OpenBLAS
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
|
||||
7z x blas.zip -oblas -y
|
||||
copy blas/include/cblas.h .
|
||||
copy blas/include/openblas_config.h .
|
||||
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy libopenblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy libopenblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
windows-cublas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [x64]
|
||||
cublas: [ON]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
- name: Install CUDA Toolkit
|
||||
id: cuda-toolkit
|
||||
uses: Jimver/cuda-toolkit@v0.2.10
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_CUBLAS=1
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-cublas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
||||
ios:
|
||||
runs-on: macos-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Configure
|
||||
run: cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
|
||||
|
||||
- name: Build objc example
|
||||
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
|
||||
|
||||
- name: Build swiftui example
|
||||
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
|
||||
|
||||
android:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Install Java
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
distribution: zulu
|
||||
java-version: 17
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
48
.github/workflows/examples.yml
vendored
Normal file
48
.github/workflows/examples.yml
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
name: Examples Tests
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- examples/addon.node/**
|
||||
- whisper.h
|
||||
pull_request:
|
||||
paths:
|
||||
- examples/addon.node/**
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
addon_node-ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [ 16.x, 18.x ]
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v1
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: 'npm'
|
||||
|
||||
- name: Install package.json dependencies
|
||||
working-directory: ./examples/addon.node
|
||||
run: npm install
|
||||
|
||||
- name: Compile addon.node
|
||||
run: npx cmake-js compile -T whisper-addon -B Release
|
||||
|
||||
- name: Download test model
|
||||
run: |
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
- name: Test
|
||||
run: |
|
||||
cd examples/addon.node
|
||||
npm run test
|
64
.github/workflows/release-deb.yml
vendored
Normal file
64
.github/workflows/release-deb.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
name: release-deb
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Configure
|
||||
run: |
|
||||
set -x -e
|
||||
VERSION=$(echo $GITHUB_REF | cut --delimiter=/ -f 3)
|
||||
ID="whisper-cpp-small_${VERSION}_amd64"
|
||||
|
||||
echo "PKG_VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "PKG_ID=$ID" >> $GITHUB_ENV
|
||||
|
||||
- name: Install deps
|
||||
run: |
|
||||
sudo apt install -y --no-install-recommends intel-mkl
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -S . -B build-release -D BUILD_SHARED_LIBS=OFF
|
||||
cd build-release
|
||||
make
|
||||
cd ..
|
||||
|
||||
- name: Create package tree
|
||||
env:
|
||||
GITHUB_REPO: ${{ github.repository }}
|
||||
run: |
|
||||
export ROOT=$PKG_ID/opt/project/whisper.cpp
|
||||
mkdir -p $ROOT/bin
|
||||
mkdir -p $ROOT/share
|
||||
mkdir -p $PKG_ID/DEBIAN
|
||||
|
||||
cp build-release/bin/main $ROOT/bin/whisper
|
||||
cp -r contrib/debian/control $PKG_ID/DEBIAN/
|
||||
|
||||
echo "Version: $PKG_VERSION" >> $PKG_ID/DEBIAN/control
|
||||
echo "Git-Repo: $GITHUB_REPO" >> $PKG_ID/DEBIAN/control
|
||||
echo "Git-Commit: $GITHUB_SHA" >> $PKG_ID/DEBIAN/control
|
||||
|
||||
models/download-ggml-model.sh small
|
||||
build-release/bin/quantize models/ggml-small.bin \
|
||||
$ROOT/share/ggml-small-q5_1.bin q5_1
|
||||
|
||||
- name: Create deb package
|
||||
run: |
|
||||
mkdir artifacts
|
||||
dpkg-deb --build --root-owner-group $PKG_ID
|
||||
|
||||
- name: Upload Release Asset
|
||||
uses: xresloader/upload-to-github-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ github.event.release.id }}
|
||||
file: ${{ env.PKG_ID }}.deb
|
14
.gitignore
vendored
14
.gitignore
vendored
@ -1,7 +1,11 @@
|
||||
*.o
|
||||
*.a
|
||||
.cache/
|
||||
.coreml/
|
||||
.test/
|
||||
.vs/
|
||||
.vscode/
|
||||
.idea/
|
||||
.DS_Store
|
||||
|
||||
build/
|
||||
@ -9,15 +13,21 @@ build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
build-static/
|
||||
build-cublas/
|
||||
build-no-accel/
|
||||
build-sanitize-addr/
|
||||
build-sanitize-thread/
|
||||
cmake-build-debug/
|
||||
|
||||
/main
|
||||
/stream
|
||||
/command
|
||||
/talk
|
||||
/talk-llama
|
||||
/bench
|
||||
/quantize
|
||||
|
||||
arm_neon.h
|
||||
sync.sh
|
||||
libwhisper.a
|
||||
libwhisper.so
|
||||
@ -29,3 +39,7 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
|
||||
|
||||
extra/bench-gg.txt
|
||||
|
||||
models/*.mlmodel
|
||||
models/*.mlmodelc
|
||||
models/*.mlpackage
|
||||
|
173
CMakeLists.txt
173
CMakeLists.txt
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
|
||||
project(whisper.cpp VERSION 1.1.0)
|
||||
project(whisper.cpp VERSION 1.4.1)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
@ -35,30 +35,36 @@ endif()
|
||||
|
||||
# options
|
||||
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
|
||||
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
|
||||
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
|
||||
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
|
||||
|
||||
option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
|
||||
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
|
||||
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
option(WHISPER_NO_F16C "whisper: disable F16c" OFF)
|
||||
|
||||
if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
||||
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
|
||||
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
|
||||
endif()
|
||||
|
||||
option(WHISPER_PERF "whisper: enable perf timings" OFF)
|
||||
option(WHISPER_PERF "whisper: enable perf timings" OFF)
|
||||
|
||||
# sanitizers
|
||||
|
||||
@ -86,20 +92,41 @@ endif()
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# on APPLE - include Accelerate framework
|
||||
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
# on APPLE
|
||||
if (APPLE)
|
||||
# include Accelerate framework
|
||||
if (NOT WHISPER_NO_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML)
|
||||
find_library(FOUNDATION_FRAMEWORK Foundation)
|
||||
find_library(COREML_FRAMEWORK CoreML)
|
||||
|
||||
if (COREML_FRAMEWORK)
|
||||
message(STATUS "CoreML framework found")
|
||||
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
|
||||
else()
|
||||
message(WARNING "CoreML framework not found")
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML_ALLOW_FALLBACK)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML_ALLOW_FALLBACK)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_SUPPORT_OPENBLAS)
|
||||
if (WHISPER_OPENBLAS)
|
||||
find_library(OPENBLAS_LIB
|
||||
NAMES openblas libopenblas
|
||||
)
|
||||
@ -113,6 +140,46 @@ if (WHISPER_SUPPORT_OPENBLAS)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_CUBLAS)
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "cuBLAS found")
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
|
||||
if (WHISPER_STATIC)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
else()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(WARNING "cuBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_CLBLAST)
|
||||
find_package(CLBlast)
|
||||
if (CLBlast_FOUND)
|
||||
message(STATUS "CLBlast found")
|
||||
|
||||
set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CLBLAST)
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
|
||||
else()
|
||||
message(WARNING "CLBlast not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# compiler flags
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
|
||||
@ -172,7 +239,9 @@ else()
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
if(NOT WHISPER_NO_F16C)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@ -181,6 +250,33 @@ if (WHISPER_PERF)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
||||
endif()
|
||||
|
||||
#
|
||||
# whisper.coreml - Core ML support
|
||||
#
|
||||
|
||||
if (WHISPER_COREML)
|
||||
set(TARGET whisper.coreml)
|
||||
|
||||
add_library(${TARGET}
|
||||
coreml/whisper-encoder.h
|
||||
coreml/whisper-encoder.mm
|
||||
coreml/whisper-encoder-impl.h
|
||||
coreml/whisper-encoder-impl.m
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
COMPILE_FLAGS "-fobjc-arc"
|
||||
)
|
||||
endif()
|
||||
|
||||
#
|
||||
# whisper - this is the main library of the project
|
||||
#
|
||||
@ -190,6 +286,8 @@ set(TARGET whisper)
|
||||
add_library(${TARGET}
|
||||
ggml.h
|
||||
ggml.c
|
||||
${GGML_CUDA_SOURCES}
|
||||
${GGML_OPENCL_SOURCES}
|
||||
whisper.h
|
||||
whisper.cpp
|
||||
)
|
||||
@ -200,6 +298,10 @@ target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
|
||||
if (WHISPER_COREML)
|
||||
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
@ -215,7 +317,19 @@ if (BUILD_SHARED_LIBS)
|
||||
|
||||
target_compile_definitions(${TARGET} PUBLIC
|
||||
WHISPER_SHARED
|
||||
GGML_SHARED
|
||||
)
|
||||
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
WHISPER_BUILD
|
||||
GGML_BUILD
|
||||
)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_SOURCES)
|
||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
@ -226,10 +340,13 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
PUBLIC_HEADER DESTINATION include
|
||||
)
|
||||
|
||||
#
|
||||
@ -242,7 +359,7 @@ add_subdirectory(bindings)
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
|
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Georgi Gerganov
|
||||
Copyright (c) 2023 Georgi Gerganov
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
153
Makefile
153
Makefile
@ -1,3 +1,5 @@
|
||||
default: main bench quantize
|
||||
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
@ -30,10 +32,16 @@ endif
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
CFLAGS = -I. -O3 -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
|
||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
|
||||
ifneq ($(wildcard /usr/include/musl/*),)
|
||||
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
endif
|
||||
|
||||
# OS specific
|
||||
# TODO: support Windows
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
@ -71,10 +79,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Linux)
|
||||
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
@ -86,16 +90,17 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
|
||||
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
@ -107,6 +112,11 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
endif
|
||||
else
|
||||
CFLAGS += -mfma -mf16c -mavx -mavx2
|
||||
@ -115,12 +125,18 @@ endif
|
||||
ifeq ($(UNAME_M),amd64)
|
||||
CFLAGS += -mavx -mavx2 -mfma -mf16c
|
||||
endif
|
||||
ifeq ($(UNAME_M),ppc64le)
|
||||
|
||||
ifneq ($(filter ppc64%,$(UNAME_M)),)
|
||||
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
||||
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
||||
CFLAGS += -mpower9-vector
|
||||
endif
|
||||
# Require c++23's std::byteswap for big-endian support.
|
||||
ifeq ($(UNAME_M),ppc64)
|
||||
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
|
||||
endif
|
||||
endif
|
||||
|
||||
ifndef WHISPER_NO_ACCELERATE
|
||||
# Mac M1 - include Accelerate framework
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -128,27 +144,71 @@ ifndef WHISPER_NO_ACCELERATE
|
||||
LDFLAGS += -framework Accelerate
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef WHISPER_COREML
|
||||
CXXFLAGS += -DWHISPER_USE_COREML
|
||||
LDFLAGS += -framework Foundation -framework CoreML
|
||||
|
||||
ifdef WHISPER_COREML_ALLOW_FALLBACK
|
||||
CXXFLAGS += -DWHISPER_COREML_ALLOW_FALLBACK
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef WHISPER_OPENBLAS
|
||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
|
||||
LDFLAGS += -lopenblas
|
||||
endif
|
||||
|
||||
ifdef WHISPER_CUBLAS
|
||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
||||
WHISPER_OBJ += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
endif
|
||||
|
||||
ifdef WHISPER_CLBLAST
|
||||
CFLAGS += -DGGML_USE_CLBLAST
|
||||
LDFLAGS += -lclblast -lOpenCL
|
||||
WHISPER_OBJ += ggml-opencl.o
|
||||
|
||||
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
endif
|
||||
|
||||
ifdef WHISPER_GPROF
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
endif
|
||||
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
CFLAGS += -mcpu=native
|
||||
CXXFLAGS += -mcpu=native
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||
# Raspberry Pi 1, 2, 3
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
|
||||
# 32-bit Raspberry Pi 1, 2, 3
|
||||
CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv7%,$(UNAME_M)),)
|
||||
# Raspberry Pi 4
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
||||
# 32-bit ARM, for example on Armbian or possibly raspbian
|
||||
#CFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
#CXXFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
|
||||
# 64-bit ARM on 32-bit OS, use these (TODO: auto-detect 64-bit)
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
endif
|
||||
|
||||
ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
# Raspberry Pi 4
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
|
||||
endif
|
||||
|
||||
#
|
||||
@ -166,26 +226,36 @@ $(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
default: main
|
||||
|
||||
#
|
||||
# Build library
|
||||
#
|
||||
|
||||
ggml.o: ggml.c ggml.h
|
||||
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
|
||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
whisper.o: whisper.cpp whisper.h
|
||||
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
libwhisper.a: ggml.o whisper.o
|
||||
$(AR) rcs libwhisper.a ggml.o whisper.o
|
||||
ifndef WHISPER_COREML
|
||||
WHISPER_OBJ += whisper.o
|
||||
else
|
||||
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
|
||||
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
|
||||
|
||||
libwhisper.so: ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
|
||||
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
|
||||
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
|
||||
|
||||
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
|
||||
endif
|
||||
|
||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||
|
||||
libwhisper.so: ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
|
||||
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -193,21 +263,30 @@ clean:
|
||||
|
||||
CC_SDL=`sdl2-config --cflags --libs`
|
||||
|
||||
main: examples/main/main.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
|
||||
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
|
||||
SRC_COMMON_SDL = examples/common-sdl.cpp
|
||||
|
||||
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
./main -h
|
||||
|
||||
stream: examples/stream/stream.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
|
||||
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
||||
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
bench: examples/bench/bench.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
#
|
||||
# Audio samples
|
||||
|
356
README.md
356
README.md
@ -1,21 +1,26 @@
|
||||
# whisper.cpp
|
||||
|
||||

|
||||
|
||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.0.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.0.4) / Beta: [v1.1.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.1.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Beta: [v1.4.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.1) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Runs on the CPU
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
Supported platforms:
|
||||
@ -58,7 +63,9 @@ the Accelerate framework utilizes the special-purpose AMX coprocessor available
|
||||
|
||||
## Quick start
|
||||
|
||||
First, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
First clone the repository.
|
||||
|
||||
Then, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
|
||||
```bash
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -89,35 +96,37 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -137,7 +146,8 @@ Running base.en on all samples in ./samples ...
|
||||
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
|
||||
----------------------------------------------
|
||||
|
||||
whisper_model_load: loading model from 'models/ggml-base.en.bin'
|
||||
whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
|
||||
whisper_model_load: loading model
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 512
|
||||
@ -150,13 +160,14 @@ whisper_model_load: n_text_layer = 6
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 2
|
||||
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
|
||||
whisper_model_load: kv self size = 5.25 MB
|
||||
whisper_model_load: kv cross size = 17.58 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: mem_required = 506.00 MB
|
||||
whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model ctx = 140.60 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
|
||||
|
||||
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
@ -164,12 +175,13 @@ main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 proc
|
||||
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
|
||||
whisper_print_timings: load time = 105.91 ms
|
||||
whisper_print_timings: mel time = 24.62 ms
|
||||
whisper_print_timings: sample time = 3.63 ms
|
||||
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
|
||||
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
|
||||
whisper_print_timings: total time = 542.81 ms
|
||||
whisper_print_timings: fallbacks = 0 p / 0 h
|
||||
whisper_print_timings: load time = 113.81 ms
|
||||
whisper_print_timings: mel time = 15.40 ms
|
||||
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
|
||||
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
|
||||
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
|
||||
whisper_print_timings: total time = 476.31 ms
|
||||
```
|
||||
|
||||
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
|
||||
@ -212,16 +224,122 @@ make large
|
||||
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
|
||||
## Quantization
|
||||
|
||||
`whisper.cpp` supports integer quantization of the Whisper `ggml` models.
|
||||
Quantized models require less memory and disk space and depending on the hardware can be processed more efficiently.
|
||||
|
||||
Here are the steps for creating and using a quantized model:
|
||||
|
||||
```bash
|
||||
# quantize a model with Q5_0 method
|
||||
make quantize
|
||||
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
|
||||
|
||||
# run the examples as usual, specifying the quantized model file
|
||||
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
|
||||
```
|
||||
|
||||
## Core ML support
|
||||
|
||||
On Apple Silicon devices, the Encoder inference can be executed on the Apple Neural Engine (ANE) via Core ML. This can result in significant
|
||||
speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions for generating a Core ML model and using it with `whisper.cpp`:
|
||||
|
||||
- Install Python dependencies needed for the creation of the Core ML model:
|
||||
|
||||
```bash
|
||||
pip install ane_transformers
|
||||
pip install openai-whisper
|
||||
pip install coremltools
|
||||
```
|
||||
|
||||
- Generate a Core ML model. For example, to generate a `base.en` model, use:
|
||||
|
||||
```bash
|
||||
./models/generate-coreml-model.sh base.en
|
||||
```
|
||||
|
||||
This will generate the folder `models/ggml-base.en-encoder.mlmodelc`
|
||||
|
||||
- Build `whisper.cpp` with Core ML support:
|
||||
|
||||
```bash
|
||||
# using Makefile
|
||||
make clean
|
||||
WHISPER_COREML=1 make -j
|
||||
|
||||
# using CMake
|
||||
cd build
|
||||
cmake -DWHISPER_COREML=1 ..
|
||||
```
|
||||
|
||||
- Run the examples as usual. For example:
|
||||
|
||||
```bash
|
||||
./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
|
||||
...
|
||||
|
||||
whisper_init_state: loading Core ML model from 'models/ggml-base.en-encoder.mlmodelc'
|
||||
whisper_init_state: first run on a device may take a while ...
|
||||
whisper_init_state: Core ML model loaded
|
||||
|
||||
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
|
||||
Next runs are faster.
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
|
||||
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
|
||||
With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extend through cuBLAS.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
|
||||
```
|
||||
make clean
|
||||
WHISPER_CUBLAS=1 make -j
|
||||
```
|
||||
|
||||
## OpenCL GPU support via CLBlast
|
||||
|
||||
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APU's or low end devices for up to ~2x speedup.
|
||||
|
||||
First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast
|
||||
|
||||
Now build `whisper.cpp` with CLBlast support:
|
||||
|
||||
```
|
||||
Makefile:
|
||||
cd whisper.cpp
|
||||
make clean
|
||||
WHISPER_CLBLAST=1 make -j
|
||||
|
||||
CMake:
|
||||
cd whisper.cpp ; mkdir build ; cd build
|
||||
cmake -DWHISPER_CLBLAST=ON ..
|
||||
make clean
|
||||
make -j
|
||||
cp bin/* ../
|
||||
```
|
||||
|
||||
|
||||
Run all the examples as usual.
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
- No GPU support (yet)
|
||||
|
||||
## Another example
|
||||
|
||||
@ -234,7 +352,8 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
|
||||
```java
|
||||
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
|
||||
|
||||
whisper_model_load: loading model from 'models/ggml-medium.en.bin'
|
||||
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
|
||||
whisper_model_load: loading model
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 1024
|
||||
@ -247,62 +366,67 @@ whisper_model_load: n_text_layer = 24
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 4
|
||||
whisper_model_load: mem_required = 2610.00 MB
|
||||
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
|
||||
whisper_model_load: kv self size = 42.00 MB
|
||||
whisper_model_load: kv cross size = 140.62 MB
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: ggml ctx size = 1644.97 MB
|
||||
whisper_model_load: memory size = 182.62 MB
|
||||
whisper_model_load: model size = 1462.12 MB
|
||||
whisper_model_load: model ctx = 1462.35 MB
|
||||
whisper_model_load: model size = 1462.12 MB
|
||||
|
||||
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
|
||||
|
||||
[00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
|
||||
[00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:29.000 --> 00:32.000] On board was a crew of seven.
|
||||
[00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
|
||||
[00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
|
||||
[00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
|
||||
[00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
|
||||
[00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
|
||||
[01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
|
||||
[01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
|
||||
[01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
|
||||
[01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
|
||||
[01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
|
||||
[01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
|
||||
[01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
|
||||
[01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
|
||||
[01:52.000 --> 01:56.000] The cause in which they died will continue.
|
||||
[01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
|
||||
[02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
|
||||
[02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
|
||||
[02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
|
||||
[02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
|
||||
[02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
|
||||
[02:35.000 --> 02:39.000] and calls them each by name."
|
||||
[02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
|
||||
[02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
|
||||
[03:01.000 --> 03:05.000] yet we can pray that all are safely home.
|
||||
[03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
|
||||
[03:13.000 --> 03:41.000] Audio
|
||||
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
|
||||
whisper_print_timings: load time = 575.92 ms
|
||||
whisper_print_timings: mel time = 230.60 ms
|
||||
whisper_print_timings: sample time = 73.19 ms
|
||||
whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
|
||||
whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
|
||||
whisper_print_timings: total time = 33686.27 ms
|
||||
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
|
||||
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
|
||||
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
|
||||
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
|
||||
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
|
||||
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
|
||||
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
|
||||
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
|
||||
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
|
||||
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
|
||||
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
|
||||
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
|
||||
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
|
||||
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
|
||||
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
|
||||
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
|
||||
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
|
||||
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
|
||||
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
|
||||
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
|
||||
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
|
||||
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
|
||||
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
|
||||
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
|
||||
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
|
||||
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
|
||||
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
|
||||
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
|
||||
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
|
||||
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
|
||||
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
|
||||
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
|
||||
[00:03:13.000 --> 00:03:19.000] [Silence]
|
||||
|
||||
|
||||
whisper_print_timings: fallbacks = 1 p / 0 h
|
||||
whisper_print_timings: load time = 569.03 ms
|
||||
whisper_print_timings: mel time = 146.85 ms
|
||||
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
|
||||
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
|
||||
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
|
||||
whisper_print_timings: total time = 32733.52 ms
|
||||
```
|
||||
</details>
|
||||
|
||||
## Real-time audio input example
|
||||
|
||||
This is a naive example of performing real-time inference on audio from your microphone.
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
@ -317,18 +441,22 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
|
||||
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
|
||||
to highlight words with high or low confidence:
|
||||
|
||||
```java
|
||||
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
|
||||
```
|
||||
|
||||
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
|
||||
|
||||
## Controlling the length of the generated text segments (experimental)
|
||||
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
@ -352,11 +480,11 @@ The `--max-len` argument can be used to obtain word-level timestamps. Simply use
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
|
||||
|
||||
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
|
||||
|
||||
[00:00:00.000 --> 00:00:00.320]
|
||||
[00:00:00.000 --> 00:00:00.320]
|
||||
[00:00:00.320 --> 00:00:00.370] And
|
||||
[00:00:00.370 --> 00:00:00.690] so
|
||||
[00:00:00.690 --> 00:00:00.850] my
|
||||
@ -422,6 +550,19 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
|
||||
|
||||
---
|
||||
|
||||
## Video comparison of different models
|
||||
|
||||
Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
|
||||
|
||||
```java
|
||||
./extra/bench-wts.sh samples/jfk.wav
|
||||
ffplay ./samples/jfk.wav.all.mp4
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8c3b-eb9f9c39d6fc.mp4
|
||||
|
||||
---
|
||||
|
||||
## Benchmarks
|
||||
|
||||
In order to have an objective comparison of the performance of the inference across different system configurations,
|
||||
@ -442,7 +583,7 @@ The original models are converted to a custom binary format. This allows to pack
|
||||
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
|
||||
or manually from here:
|
||||
|
||||
- https://huggingface.co/datasets/ggerganov/whisper.cpp
|
||||
- https://huggingface.co/ggerganov/whisper.cpp
|
||||
- https://ggml.ggerganov.com
|
||||
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
|
||||
@ -452,9 +593,19 @@ in [models](models).
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
|
||||
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
||||
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
||||
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
||||
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
||||
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
||||
- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
||||
- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
||||
|
||||
## Examples
|
||||
|
||||
@ -468,6 +619,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
|
@ -17,9 +17,9 @@ import (
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -25,6 +25,8 @@ func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s\n", context.SystemInfo())
|
||||
|
||||
// Open the file
|
||||
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
||||
fh, err := os.Open(path)
|
||||
@ -64,10 +66,13 @@ func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
|
||||
// Process the data
|
||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||
context.ResetTimings()
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
context.PrintTimings()
|
||||
|
||||
// Print out the results
|
||||
switch {
|
||||
case flags.GetOut() == "srt":
|
||||
|
@ -49,6 +49,10 @@ func (p *Params) SetSpeedup(v bool) {
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
if lang == -1 {
|
||||
p.language = nil
|
||||
return nil
|
||||
}
|
||||
str := C.whisper_lang_str(C.int(lang))
|
||||
if str == nil {
|
||||
return ErrInvalidLanguage
|
||||
@ -66,6 +70,11 @@ func (p *Params) Language() int {
|
||||
return int(C.whisper_lang_id(p.language))
|
||||
}
|
||||
|
||||
// Threads available
|
||||
func (p *Params) Threads() int {
|
||||
return int(p.n_threads)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (p *Params) SetThreads(threads int) {
|
||||
p.n_threads = C.int(threads)
|
||||
@ -96,6 +105,10 @@ func (p *Params) SetMaxSegmentLength(n int) {
|
||||
p.max_len = C.int(n)
|
||||
}
|
||||
|
||||
func (p *Params) SetTokenTimestamps(b bool) {
|
||||
p.token_timestamps = toBool(b)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (p *Params) SetMaxTokensPerSegment(n int) {
|
||||
p.max_tokens = C.int(n)
|
||||
|
@ -1,7 +1,9 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -44,7 +46,10 @@ func (context *context) SetLanguage(lang string) error {
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
|
||||
if lang == "auto" {
|
||||
context.params.SetLanguage(-1)
|
||||
} else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
@ -59,6 +64,10 @@ func (context *context) IsMultilingual() bool {
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
id := context.params.Language()
|
||||
if id == -1 {
|
||||
return "auto"
|
||||
}
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
@ -102,11 +111,46 @@ func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set token timestamps flag
|
||||
func (context *context) SetTokenTimestamps(b bool) {
|
||||
context.params.SetTokenTimestamps(b)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// ResetTimings resets the mode timings. Should be called before processing
|
||||
func (context *context) ResetTimings() {
|
||||
context.model.ctx.Whisper_reset_timings()
|
||||
}
|
||||
|
||||
// PrintTimings prints the model timings to stdout.
|
||||
func (context *context) PrintTimings() {
|
||||
context.model.ctx.Whisper_print_timings()
|
||||
}
|
||||
|
||||
// SystemInfo returns the system information
|
||||
func (context *context) SystemInfo() string {
|
||||
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
|
||||
context.params.Threads(),
|
||||
runtime.NumCPU(),
|
||||
whisper.Whisper_print_system_info(),
|
||||
)
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
|
||||
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return langProbs, nil
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
@ -241,10 +285,14 @@ func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
data := ctx.Whisper_full_get_token_data(n, i)
|
||||
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: ctx.Whisper_full_get_token_text(n, i),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
Start: time.Duration(data.T0()) * time.Millisecond * 10,
|
||||
End: time.Duration(data.T1()) * time.Millisecond * 10,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
@ -29,7 +29,7 @@ type Model interface {
|
||||
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
@ -41,6 +41,7 @@ type Context interface {
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
@ -60,6 +61,12 @@ type Context interface {
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
|
||||
// Timings
|
||||
PrintTimings()
|
||||
ResetTimings()
|
||||
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
@ -79,7 +86,8 @@ type Segment struct {
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
Start, End time.Duration
|
||||
}
|
||||
|
@ -94,6 +94,7 @@ func (model *model) NewContext() (Context, error) {
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
params.SetNoContext(true)
|
||||
|
||||
// Return new context
|
||||
return newContext(model, params)
|
||||
|
@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callNewSegment(user_data, n_new);
|
||||
}
|
||||
@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
return callEncoderBegin(user_data);
|
||||
}
|
||||
@ -356,7 +356,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
@ -407,3 +407,11 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (t TokenData) T0() int64 {
|
||||
return int64(t.t0)
|
||||
}
|
||||
|
||||
func (t TokenData) T1() int64 {
|
||||
return int64(t.t1)
|
||||
}
|
||||
|
Submodule bindings/ios updated: f6334b026f...af745e4f2f
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.1.0",
|
||||
"version": "1.4.1",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
File diff suppressed because one or more lines are too long
7
bindings/ruby/ext/.gitignore
vendored
Normal file
7
bindings/ruby/ext/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
Makefile
|
||||
ggml.c
|
||||
ggml.h
|
||||
whisper.bundle
|
||||
whisper.cpp
|
||||
whisper.h
|
||||
dr_wav.h
|
21
bindings/ruby/ext/extconf.rb
Normal file
21
bindings/ruby/ext/extconf.rb
Normal file
@ -0,0 +1,21 @@
|
||||
require 'mkmf'
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
||||
|
||||
|
||||
# need to use c++ compiler flags
|
||||
$CXXFLAGS << ' -std=c++11'
|
||||
# Set to true when building binary gems
|
||||
if enable_config('static-stdlib', false)
|
||||
$LDFLAGS << ' -static-libgcc -static-libstdc++'
|
||||
end
|
||||
|
||||
if enable_config('march-tune-native', false)
|
||||
$CFLAGS << ' -march=native -mtune=native'
|
||||
$CXXFLAGS << ' -march=native -mtune=native'
|
||||
end
|
||||
|
||||
create_makefile('whisper')
|
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
@ -0,0 +1,426 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define BOOL_PARAMS_SETTER(self, prop, value) \
|
||||
ruby_whisper_params *rwp; \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (value == Qfalse || value == Qnil) { \
|
||||
rwp->params.prop = false; \
|
||||
} else { \
|
||||
rwp->params.prop = true; \
|
||||
} \
|
||||
return value; \
|
||||
|
||||
#define BOOL_PARAMS_GETTER(self, prop) \
|
||||
ruby_whisper_params *rwp; \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (rwp->params.prop) { \
|
||||
return Qtrue; \
|
||||
} else { \
|
||||
return Qfalse; \
|
||||
}
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
|
||||
static void ruby_whisper_free(ruby_whisper *rw) {
|
||||
if (rw->context) {
|
||||
whisper_free(rw->context);
|
||||
rw->context = NULL;
|
||||
}
|
||||
}
|
||||
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
|
||||
}
|
||||
|
||||
void rb_whisper_mark(ruby_whisper *rw) {
|
||||
// call rb_gc_mark on any ruby references in rw
|
||||
}
|
||||
|
||||
void rb_whisper_free(ruby_whisper *rw) {
|
||||
ruby_whisper_free(rw);
|
||||
free(rw);
|
||||
}
|
||||
|
||||
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
|
||||
}
|
||||
|
||||
void rb_whisper_params_free(ruby_whisper_params *rwp) {
|
||||
ruby_whisper_params_free(rwp);
|
||||
free(rwp);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_allocate(VALUE klass) {
|
||||
ruby_whisper *rw;
|
||||
rw = ALLOC(ruby_whisper);
|
||||
rw->context = NULL;
|
||||
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
||||
ruby_whisper_params *rwp;
|
||||
rwp = ALLOC(ruby_whisper_params);
|
||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
VALUE whisper_model_file_path;
|
||||
|
||||
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||
}
|
||||
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
|
||||
if (rw->context == nullptr) {
|
||||
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
*
|
||||
**/
|
||||
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
VALUE wave_file_path, blk, params;
|
||||
|
||||
rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
|
||||
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
|
||||
}
|
||||
|
||||
std::string fname_inp = StringValueCStr(wave_file_path);
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input - this is directly from main.cpp example
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true) {
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return self;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
|
||||
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (rwp->diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
VALUE output = rb_str_new2("");
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(rw->context, i);
|
||||
output = rb_str_concat(output, rb_str_new2(text));
|
||||
}
|
||||
VALUE idCall = rb_intern("call");
|
||||
if (blk != Qnil) {
|
||||
rb_funcall(blk, idCall, 1, output);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* params.language = "auto" | "en", etc...
|
||||
*/
|
||||
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->params.language = "auto";
|
||||
} else {
|
||||
rwp->params.language = StringValueCStr(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_language(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->params.language) {
|
||||
return rb_str_new2(rwp->params.language);
|
||||
} else {
|
||||
return rb_str_new2("auto");
|
||||
}
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, translate, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_translate(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, translate)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, no_context, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, no_context)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, single_segment, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, single_segment)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_special, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_special)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_progress, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_progress)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_realtime, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_realtime)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, print_timestamps, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, print_timestamps)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_blank, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_blank)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, token_timestamps)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, token_timestamps, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, split_on_word)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, speed_up)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, speed_up, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->diarize) {
|
||||
return Qtrue;
|
||||
} else {
|
||||
return Qfalse;
|
||||
}
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->diarize = false;
|
||||
} else {
|
||||
rwp->diarize = true;
|
||||
} \
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_get_offset(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.offset_ms);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.offset_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_duration(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.duration_ms);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.duration_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.n_max_text_ctx);
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.n_max_text_ctx = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
|
||||
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cContext, ruby_whisper_allocate);
|
||||
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
|
||||
|
||||
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
|
||||
|
||||
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
|
||||
|
||||
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
|
||||
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
|
||||
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
|
||||
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
|
||||
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
|
||||
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
|
||||
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
|
||||
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
|
||||
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
|
||||
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
|
||||
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
|
||||
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
|
||||
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
|
||||
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
|
||||
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
|
||||
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
|
||||
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
|
||||
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
|
||||
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
|
||||
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
|
||||
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
|
||||
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
||||
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
||||
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
||||
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
|
||||
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
|
||||
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
||||
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
||||
|
||||
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
|
||||
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
|
||||
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
|
||||
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
|
||||
|
||||
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
|
||||
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
15
bindings/ruby/ext/ruby_whisper.h
Normal file
15
bindings/ruby/ext/ruby_whisper.h
Normal file
@ -0,0 +1,15 @@
|
||||
#ifndef __RUBY_WHISPER_H
|
||||
#define __RUBY_WHISPER_H
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
typedef struct {
|
||||
struct whisper_context *context;
|
||||
} ruby_whisper;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_full_params params;
|
||||
bool diarize;
|
||||
} ruby_whisper_params;
|
||||
|
||||
#endif
|
138
bindings/ruby/tests/test_whisper.rb
Normal file
138
bindings/ruby/tests/test_whisper.rb
Normal file
@ -0,0 +1,138 @@
|
||||
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
||||
EXTDIR = File.join(TOPDIR, 'ext')
|
||||
#$LIBDIR = File.join(TOPDIR, 'lib')
|
||||
#$:.unshift(LIBDIR)
|
||||
$:.unshift(EXTDIR)
|
||||
|
||||
require 'whisper'
|
||||
require 'test/unit'
|
||||
|
||||
class TestWhisper < Test::Unit::TestCase
|
||||
def setup
|
||||
@params = Whisper::Params.new
|
||||
end
|
||||
|
||||
def test_language
|
||||
@params.language = "en"
|
||||
assert_equal @params.language, "en"
|
||||
@params.language = "auto"
|
||||
assert_equal @params.language, "auto"
|
||||
end
|
||||
|
||||
def test_offset
|
||||
@params.offset = 10_000
|
||||
assert_equal @params.offset, 10_000
|
||||
@params.offset = 0
|
||||
assert_equal @params.offset, 0
|
||||
end
|
||||
|
||||
def test_duration
|
||||
@params.duration = 60_000
|
||||
assert_equal @params.duration, 60_000
|
||||
@params.duration = 0
|
||||
assert_equal @params.duration, 0
|
||||
end
|
||||
|
||||
def test_max_text_tokens
|
||||
@params.max_text_tokens = 300
|
||||
assert_equal @params.max_text_tokens, 300
|
||||
@params.max_text_tokens = 0
|
||||
assert_equal @params.max_text_tokens, 0
|
||||
end
|
||||
|
||||
def test_translate
|
||||
@params.translate = true
|
||||
assert @params.translate
|
||||
@params.translate = false
|
||||
assert !@params.translate
|
||||
end
|
||||
|
||||
def test_no_context
|
||||
@params.no_context = true
|
||||
assert @params.no_context
|
||||
@params.no_context = false
|
||||
assert !@params.no_context
|
||||
end
|
||||
|
||||
def test_single_segment
|
||||
@params.single_segment = true
|
||||
assert @params.single_segment
|
||||
@params.single_segment = false
|
||||
assert !@params.single_segment
|
||||
end
|
||||
|
||||
def test_print_special
|
||||
@params.print_special = true
|
||||
assert @params.print_special
|
||||
@params.print_special = false
|
||||
assert !@params.print_special
|
||||
end
|
||||
|
||||
def test_print_progress
|
||||
@params.print_progress = true
|
||||
assert @params.print_progress
|
||||
@params.print_progress = false
|
||||
assert !@params.print_progress
|
||||
end
|
||||
|
||||
def test_print_realtime
|
||||
@params.print_realtime = true
|
||||
assert @params.print_realtime
|
||||
@params.print_realtime = false
|
||||
assert !@params.print_realtime
|
||||
end
|
||||
|
||||
def test_print_timestamps
|
||||
@params.print_timestamps = true
|
||||
assert @params.print_timestamps
|
||||
@params.print_timestamps = false
|
||||
assert !@params.print_timestamps
|
||||
end
|
||||
|
||||
def test_suppress_blank
|
||||
@params.suppress_blank = true
|
||||
assert @params.suppress_blank
|
||||
@params.suppress_blank = false
|
||||
assert !@params.suppress_blank
|
||||
end
|
||||
|
||||
def test_suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = true
|
||||
assert @params.suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = false
|
||||
assert !@params.suppress_non_speech_tokens
|
||||
end
|
||||
|
||||
def test_token_timestamps
|
||||
@params.token_timestamps = true
|
||||
assert @params.token_timestamps
|
||||
@params.token_timestamps = false
|
||||
assert !@params.token_timestamps
|
||||
end
|
||||
|
||||
def test_split_on_word
|
||||
@params.split_on_word = true
|
||||
assert @params.split_on_word
|
||||
@params.split_on_word = false
|
||||
assert !@params.split_on_word
|
||||
end
|
||||
|
||||
def test_speed_up
|
||||
@params.speed_up = true
|
||||
assert @params.speed_up
|
||||
@params.speed_up = false
|
||||
assert !@params.speed_up
|
||||
end
|
||||
|
||||
def test_whisper
|
||||
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
|
||||
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
||||
@whisper.transcribe(jfk, params) {|text|
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
|
||||
}
|
||||
end
|
||||
|
||||
end
|
5
contrib/debian/control
Normal file
5
contrib/debian/control
Normal file
@ -0,0 +1,5 @@
|
||||
Package: whisper-small-cpp
|
||||
Architecture: amd64
|
||||
Maintainer: Alexey Kharlamov <alexey@kharlamov.biz>
|
||||
Description: Whisper Speech to Text Converter
|
||||
Depends: libc6 (>= 2.2.1), intel-mkl
|
146
coreml/whisper-decoder-impl.h
Normal file
146
coreml/whisper-decoder-impl.h
Normal file
@ -0,0 +1,146 @@
|
||||
//
|
||||
// whisper-decoder-impl.h
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <CoreML/CoreML.h>
|
||||
#include <stdint.h>
|
||||
#include <os/log.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
||||
/// Model Prediction Input Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// token_data as 1 by 1 matrix of 32-bit integers
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
|
||||
|
||||
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Model Prediction Output Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// var_1346 as multidimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Class for model loading and prediction
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_decoder_impl : NSObject
|
||||
@property (readonly, nonatomic, nullable) MLModel * model;
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_decoder_implInput to predict from
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_decoder_implInput to predict from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param token_data as 1 by 1 matrix of 32-bit integers:
|
||||
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_decoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Batch prediction
|
||||
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the predictions as NSArray<whisper_decoder_implOutput *>
|
||||
*/
|
||||
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
201
coreml/whisper-decoder-impl.m
Normal file
201
coreml/whisper-decoder-impl.m
Normal file
@ -0,0 +1,201 @@
|
||||
//
|
||||
// whisper-decoder-impl.m
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
|
||||
#endif
|
||||
|
||||
#import "whisper-decoder-impl.h"
|
||||
|
||||
@implementation whisper_decoder_implInput
|
||||
|
||||
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_token_data = token_data;
|
||||
_audio_data = audio_data;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"token_data", @"audio_data"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"token_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.token_data];
|
||||
}
|
||||
if ([featureName isEqualToString:@"audio_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.audio_data];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_decoder_implOutput
|
||||
|
||||
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_var_1346 = var_1346;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"var_1346"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"var_1346"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_decoder_impl
|
||||
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle {
|
||||
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"];
|
||||
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; }
|
||||
return [NSURL fileURLWithPath:assetPath];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model {
|
||||
self = [super init];
|
||||
if (!self) { return nil; }
|
||||
_model = model;
|
||||
if (_model == nil) { return nil; }
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_decoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
|
||||
configuration:configuration
|
||||
completionHandler:handler];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[MLModel loadContentsOfURL:modelURL
|
||||
configuration:configuration
|
||||
completionHandler:^(MLModel *model, NSError *error) {
|
||||
if (model != nil) {
|
||||
whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model];
|
||||
handler(typedModel, nil);
|
||||
} else {
|
||||
handler(nil, error);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
|
||||
if (!outFeatures) { return nil; }
|
||||
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
|
||||
}
|
||||
|
||||
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data];
|
||||
return [self predictionFromFeatures:input_ error:error];
|
||||
}
|
||||
|
||||
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
|
||||
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
|
||||
if (!outBatch) { return nil; }
|
||||
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
|
||||
for (NSInteger i = 0; i < outBatch.count; i++) {
|
||||
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
|
||||
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
|
||||
[results addObject:result];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@end
|
142
coreml/whisper-encoder-impl.h
Normal file
142
coreml/whisper-encoder-impl.h
Normal file
@ -0,0 +1,142 @@
|
||||
//
|
||||
// whisper-encoder-impl.h
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <CoreML/CoreML.h>
|
||||
#include <stdint.h>
|
||||
#include <os/log.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
||||
/// Model Prediction Input Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Model Prediction Output Type
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
|
||||
|
||||
/// output as multidimensional array of floats
|
||||
@property (readwrite, nonatomic, strong) MLMultiArray * output;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
||||
/// Class for model loading and prediction
|
||||
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
|
||||
@interface whisper_encoder_impl : NSObject
|
||||
@property (readonly, nonatomic, nullable) MLModel * model;
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_encoder_implInput to predict from
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the standard interface
|
||||
@param input an instance of whisper_encoder_implInput to predict from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
|
||||
/**
|
||||
Batch prediction
|
||||
@param inputArray array of whisper_encoder_implInput instances to obtain predictions from
|
||||
@param options prediction options
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the predictions as NSArray<whisper_encoder_implOutput *>
|
||||
*/
|
||||
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
197
coreml/whisper-encoder-impl.m
Normal file
197
coreml/whisper-encoder-impl.m
Normal file
@ -0,0 +1,197 @@
|
||||
//
|
||||
// whisper-encoder-impl.m
|
||||
//
|
||||
// This file was automatically generated and should not be edited.
|
||||
//
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
|
||||
#endif
|
||||
|
||||
#import "whisper-encoder-impl.h"
|
||||
|
||||
@implementation whisper_encoder_implInput
|
||||
|
||||
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_logmel_data = logmel_data;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"logmel_data"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"logmel_data"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.logmel_data];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_encoder_implOutput
|
||||
|
||||
- (instancetype)initWithOutput:(MLMultiArray *)output {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_output = output;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSSet<NSString *> *)featureNames {
|
||||
return [NSSet setWithArray:@[@"output"]];
|
||||
}
|
||||
|
||||
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
|
||||
if ([featureName isEqualToString:@"output"]) {
|
||||
return [MLFeatureValue featureValueWithMultiArray:self.output];
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation whisper_encoder_impl
|
||||
|
||||
|
||||
/**
|
||||
URL of the underlying .mlmodelc directory.
|
||||
*/
|
||||
+ (nullable NSURL *)URLOfModelInThisBundle {
|
||||
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"];
|
||||
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; }
|
||||
return [NSURL fileURLWithPath:assetPath];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from an existing MLModel object.
|
||||
|
||||
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
|
||||
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
|
||||
*/
|
||||
- (instancetype)initWithMLModel:(MLModel *)model {
|
||||
self = [super init];
|
||||
if (!self) { return nil; }
|
||||
_model = model;
|
||||
if (_model == nil) { return nil; }
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
*/
|
||||
- (nullable instancetype)init {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance with the model in this bundle.
|
||||
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Initialize whisper_encoder_impl instance from the model URL.
|
||||
|
||||
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
|
||||
@param configuration The model configuration object
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
*/
|
||||
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
|
||||
if (model == nil) { return nil; }
|
||||
return [self initWithMLModel:model];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with configuration.
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
|
||||
configuration:configuration
|
||||
completionHandler:handler];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
|
||||
|
||||
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
|
||||
|
||||
@param modelURL The model URL.
|
||||
@param configuration The model configuration
|
||||
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
|
||||
*/
|
||||
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
|
||||
[MLModel loadContentsOfURL:modelURL
|
||||
configuration:configuration
|
||||
completionHandler:^(MLModel *model, NSError *error) {
|
||||
if (model != nil) {
|
||||
whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model];
|
||||
handler(typedModel, nil);
|
||||
} else {
|
||||
handler(nil, error);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
|
||||
if (!outFeatures) { return nil; }
|
||||
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
|
||||
}
|
||||
|
||||
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
|
||||
return [self predictionFromFeatures:input_ error:error];
|
||||
}
|
||||
|
||||
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
|
||||
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
|
||||
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
|
||||
if (!outBatch) { return nil; }
|
||||
NSMutableArray<whisper_encoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
|
||||
for (NSInteger i = 0; i < outBatch.count; i++) {
|
||||
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
|
||||
whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
|
||||
[results addObject:result];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@end
|
22
coreml/whisper-encoder.h
Normal file
22
coreml/whisper-encoder.h
Normal file
@ -0,0 +1,22 @@
|
||||
// Wrapper of the Core ML Whisper Encoder model
|
||||
//
|
||||
// Code is derived from the work of Github user @wangchou
|
||||
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_coreml_context;
|
||||
|
||||
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
|
||||
void whisper_coreml_free(struct whisper_coreml_context * ctx);
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
float * mel,
|
||||
float * out);
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif
|
67
coreml/whisper-encoder.mm
Normal file
67
coreml/whisper-encoder.mm
Normal file
@ -0,0 +1,67 @@
|
||||
#import "coreml/whisper-encoder.h"
|
||||
#import "coreml/whisper-encoder-impl.h"
|
||||
|
||||
#import <CoreML/CoreML.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_coreml_context {
|
||||
const void * data;
|
||||
};
|
||||
|
||||
struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
|
||||
NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
|
||||
|
||||
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
|
||||
|
||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
|
||||
|
||||
if (data == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_coreml_context * ctx = new whisper_coreml_context;
|
||||
|
||||
ctx->data = data;
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void whisper_coreml_free(struct whisper_coreml_context * ctx) {
|
||||
CFRelease(ctx->data);
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
float * mel,
|
||||
float * out) {
|
||||
MLMultiArray * inMultiArray = [
|
||||
[MLMultiArray alloc] initWithDataPointer: mel
|
||||
shape: @[@1, @80, @3000]
|
||||
dataType: MLMultiArrayDataTypeFloat32
|
||||
strides: @[@(240000), @(3000), @1]
|
||||
deallocator: nil
|
||||
error: nil
|
||||
];
|
||||
|
||||
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
|
||||
|
||||
MLMultiArray * outMA = outCoreML.output;
|
||||
|
||||
//NSArray<NSNumber *> * shape = outMA.shape;
|
||||
//NSArray<NSNumber *> * strides = outMA.strides;
|
||||
|
||||
//printf("shape: %ld %ld %ld %ld\n", [shape[0] longValue], [shape[1] longValue], [shape[2] longValue], [shape[3] longValue]);
|
||||
//printf("strides: %ld %ld %ld %ld\n", [strides[0] longValue], [strides[1] longValue], [strides[2] longValue], [strides[3] longValue]);
|
||||
|
||||
memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
|
||||
}
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif
|
@ -4,7 +4,7 @@ find_package(Threads REQUIRED)
|
||||
|
||||
# third-party
|
||||
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# SDL2
|
||||
find_package(SDL2 REQUIRED)
|
||||
|
||||
@ -14,6 +14,41 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||
endif()
|
||||
|
||||
# common
|
||||
|
||||
set(TARGET common)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
common.h
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
if (WHISPER_SDL2)
|
||||
# common-sdl
|
||||
|
||||
set(TARGET common-sdl)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
common-sdl.h
|
||||
common-sdl.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# examples
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
@ -24,10 +59,14 @@ if (EMSCRIPTEN)
|
||||
add_subdirectory(command.wasm)
|
||||
add_subdirectory(talk.wasm)
|
||||
add_subdirectory(bench.wasm)
|
||||
elseif(CMAKE_JS_VERSION)
|
||||
add_subdirectory(addon.node)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(quantize)
|
||||
add_subdirectory(talk)
|
||||
add_subdirectory(talk-llama)
|
||||
endif()
|
||||
|
3
examples/addon.node/.gitignore
vendored
Normal file
3
examples/addon.node/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
.idea
|
||||
node_modules
|
||||
build
|
31
examples/addon.node/CMakeLists.txt
Normal file
31
examples/addon.node/CMakeLists.txt
Normal file
@ -0,0 +1,31 @@
|
||||
set(TARGET whisper-addon)
|
||||
|
||||
# Base settings
|
||||
#==================================================================
|
||||
# env var supported by cmake-js
|
||||
add_definitions(-DNAPI_VERSION=4)
|
||||
include_directories(${CMAKE_JS_INC})
|
||||
#==================================================================
|
||||
|
||||
add_library(${TARGET} SHARED ${CMAKE_JS_SRC} addon.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES PREFIX "" SUFFIX ".node")
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
# Include N-API wrappers
|
||||
#==================================================================
|
||||
execute_process(COMMAND node -p "require('node-addon-api').include"
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE NODE_ADDON_API_DIR
|
||||
)
|
||||
string(REPLACE "\n" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
|
||||
string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
|
||||
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
|
||||
#==================================================================
|
||||
|
||||
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
|
||||
# Generate node.lib
|
||||
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
||||
endif()
|
37
examples/addon.node/README.md
Normal file
37
examples/addon.node/README.md
Normal file
@ -0,0 +1,37 @@
|
||||
# addon
|
||||
|
||||
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
|
||||
It can be used as a reference for using the whisper.cpp project in other node projects.
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
npm install
|
||||
```
|
||||
|
||||
## Compile
|
||||
|
||||
Make sure it is in the project root directory and compiled with make-js.
|
||||
|
||||
```shell
|
||||
npx cmake-js compile -T whisper-addon -B Release
|
||||
```
|
||||
|
||||
For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes.
|
||||
|
||||
> Such as appointing special cmake path:
|
||||
> ```shell
|
||||
> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release
|
||||
> ```
|
||||
|
||||
## Run
|
||||
|
||||
```shell
|
||||
cd examples/addon.node
|
||||
|
||||
node index.js --language='language' --model='model-path' --fname_inp='file-path'
|
||||
```
|
||||
|
||||
Because this is a simple Demo, only the above parameters are set in the node environment.
|
||||
|
||||
Other parameters can also be specified in the node environment.
|
23
examples/addon.node/__test__/whisper.spec.js
Normal file
23
examples/addon.node/__test__/whisper.spec.js
Normal file
@ -0,0 +1,23 @@
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../../build/Release/whisper-addon"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
const whisperParamsMock = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
|
||||
};
|
||||
|
||||
describe("Run whisper.node", () => {
|
||||
test("it should receive a non-empty value", async () => {
|
||||
let result = await whisperAsync(whisperParamsMock);
|
||||
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
}, 10000);
|
||||
});
|
||||
|
338
examples/addon.node/addon.cpp
Normal file
338
examples/addon.node/addon.cpp
Normal file
@ -0,0 +1,338 @@
|
||||
#include "napi.h"
|
||||
#include "common.h"
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 5;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string model = "../../ggml-large.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
std::vector<std::string> fname_out = {};
|
||||
};
|
||||
|
||||
struct whisper_print_user_data {
|
||||
const whisper_params * params;
|
||||
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
int64_t msec = t * 10;
|
||||
int64_t hr = msec / (1000 * 60 * 60);
|
||||
msec = msec - hr * (1000 * 60 * 60);
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
int timestamp_to_sample(int64_t t, int n_samples) {
|
||||
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
||||
}
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
// colorful print bug
|
||||
//
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
if (params.fname_inp.empty()) {
|
||||
fprintf(stderr, "error: no input files specified\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (!wparams.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
// example for abort mechanism
|
||||
// in this example, we do not abort the processing, but we could if the flag is set to true
|
||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return 10;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
result.resize(n_segments);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
result[i].emplace_back(to_timestamp(t0, true));
|
||||
result[i].emplace_back(to_timestamp(t1, true));
|
||||
result[i].emplace_back(text);
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
class Worker : public Napi::AsyncWorker {
|
||||
public:
|
||||
Worker(Napi::Function& callback, whisper_params params)
|
||||
: Napi::AsyncWorker(callback), params(params) {}
|
||||
|
||||
void Execute() override {
|
||||
run(params, result);
|
||||
}
|
||||
|
||||
void OnOK() override {
|
||||
Napi::HandleScope scope(Env());
|
||||
Napi::Object res = Napi::Array::New(Env(), result.size());
|
||||
for (uint64_t i = 0; i < result.size(); ++i) {
|
||||
Napi::Object tmp = Napi::Array::New(Env(), 3);
|
||||
for (uint64_t j = 0; j < 3; ++j) {
|
||||
tmp[j] = Napi::String::New(Env(), result[i][j]);
|
||||
}
|
||||
res[i] = tmp;
|
||||
}
|
||||
Callback().Call({Env().Null(), res});
|
||||
}
|
||||
|
||||
private:
|
||||
whisper_params params;
|
||||
std::vector<std::vector<std::string>> result;
|
||||
};
|
||||
|
||||
|
||||
|
||||
Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
Napi::Env env = info.Env();
|
||||
if (info.Length() <= 0 || !info[0].IsObject()) {
|
||||
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException();
|
||||
}
|
||||
whisper_params params;
|
||||
|
||||
Napi::Object whisper_params = info[0].As<Napi::Object>();
|
||||
std::string language = whisper_params.Get("language").As<Napi::String>();
|
||||
std::string model = whisper_params.Get("model").As<Napi::String>();
|
||||
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
|
||||
|
||||
params.language = language;
|
||||
params.model = model;
|
||||
params.fname_inp.emplace_back(input);
|
||||
|
||||
Napi::Function callback = info[1].As<Napi::Function>();
|
||||
Worker* worker = new Worker(callback, params);
|
||||
worker->Queue();
|
||||
return env.Undefined();
|
||||
}
|
||||
|
||||
|
||||
Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
||||
exports.Set(
|
||||
Napi::String::New(env, "whisper"),
|
||||
Napi::Function::New(env, whisper)
|
||||
);
|
||||
return exports;
|
||||
}
|
||||
|
||||
NODE_API_MODULE(whisper, Init);
|
36
examples/addon.node/index.js
Normal file
36
examples/addon.node/index.js
Normal file
@ -0,0 +1,36 @@
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../build/Release/whisper-addon"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
const whisperParams = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
||||
fname_inp: "../../samples/jfk.wav",
|
||||
};
|
||||
|
||||
const arguments = process.argv.slice(2);
|
||||
const params = Object.fromEntries(
|
||||
arguments.reduce((pre, item) => {
|
||||
if (item.startsWith("--")) {
|
||||
return [...pre, item.slice(2).split("=")];
|
||||
}
|
||||
return pre;
|
||||
}, [])
|
||||
);
|
||||
|
||||
for (const key in params) {
|
||||
if (whisperParams.hasOwnProperty(key)) {
|
||||
whisperParams[key] = params[key];
|
||||
}
|
||||
}
|
||||
|
||||
console.log("whisperParams =", whisperParams);
|
||||
|
||||
whisperAsync(whisperParams).then((result) => {
|
||||
console.log(`Result from whisper: ${result}`);
|
||||
});
|
16
examples/addon.node/package.json
Normal file
16
examples/addon.node/package.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"name": "whisper-addon",
|
||||
"version": "0.0.0",
|
||||
"description": "",
|
||||
"main": "index.js",
|
||||
"author": "Qanhe Chen",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
"test": "jest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cmake-js": "^7.1.1",
|
||||
"jest": "^29.4.0",
|
||||
"node-addon-api": "^5.0.0"
|
||||
}
|
||||
}
|
@ -31,9 +31,9 @@ endif()
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s PTHREAD_POOL_SIZE_STRICT=0 \
|
||||
-s INITIAL_MEMORY=2000MB \
|
||||
-s TOTAL_MEMORY=2000MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use and click the "Bench" button.<br>
|
||||
@ -44,11 +53,18 @@
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
|
||||
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<button id="fetch-whisper-small-en-q5_1" onclick="loadWhisper('small-en-q5_1')">small.en (Q5_1, 182 MB)</button>
|
||||
<button id="fetch-whisper-medium-en-q5_0" onclick="loadWhisper('medium-en-q5_0')">medium.en (Q5_0, 515 MB)</button>
|
||||
<button id="fetch-whisper-large-q5_0" onclick="loadWhisper('large-q5_0')">large (Q5_0, 1030 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@ -160,6 +176,14 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
}
|
||||
@ -168,19 +192,42 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
|
||||
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
|
||||
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
'small.en': 466,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
'small-en-q5_1': 182,
|
||||
'medium-en-q5_0': 515,
|
||||
'large-q5_0': 1030,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -190,9 +237,18 @@
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-medium-en-q5_0'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-large-q5_0' ); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
|
@ -11,6 +11,7 @@ add_executable(${TARGET}
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
common
|
||||
whisper
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
@ -27,92 +28,11 @@ std::string g_transcribed = "";
|
||||
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
static std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
static float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
void command_set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
@ -155,7 +75,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
|
||||
const int64_t n_samples = (ms * sample_rate) / 1000;
|
||||
|
||||
int64_t n_take = 0;
|
||||
if (g_pcmf32.size() < n_samples) {
|
||||
if (n_samples > (int) g_pcmf32.size()) {
|
||||
n_take = g_pcmf32.size();
|
||||
} else {
|
||||
n_take = n_samples;
|
||||
@ -187,7 +107,6 @@ void command_main(size_t index) {
|
||||
|
||||
printf("command: using %d threads\n", wparams.n_threads);
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
bool print_energy = false;
|
||||
@ -233,7 +152,7 @@ void command_main(size_t index) {
|
||||
{
|
||||
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||
|
||||
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
command_set_status("Speech detected! Processing ...");
|
||||
|
||||
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and follow the instructions.
|
||||
@ -45,6 +54,10 @@
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
@ -162,11 +175,17 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -177,6 +196,10 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -188,6 +211,10 @@
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -6,11 +6,10 @@
|
||||
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
@ -110,309 +109,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
@ -467,31 +163,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
return result;
|
||||
}
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
std::vector<std::string> allowed_commands;
|
||||
|
||||
@ -502,7 +173,7 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
line = trim(line);
|
||||
line = ::trim(line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
@ -526,23 +197,6 @@ std::vector<std::string> get_words(const std::string &txt) {
|
||||
return words;
|
||||
}
|
||||
|
||||
// returns true if no exit event was received
|
||||
bool process_sdl_events() {
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
return false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
@ -634,14 +288,14 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
@ -775,7 +429,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
@ -791,7 +445,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
@ -854,7 +508,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
@ -870,7 +524,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
241
examples/common-ggml.cpp
Normal file
241
examples/common-ggml.cpp
Normal file
@ -0,0 +1,241 @@
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <regex>
|
||||
#include <map>
|
||||
|
||||
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
|
||||
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
|
||||
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
|
||||
{"q4_2", GGML_FTYPE_MOSTLY_Q4_2},
|
||||
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
|
||||
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
|
||||
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
|
||||
};
|
||||
|
||||
void ggml_print_ftypes(FILE * fp) {
|
||||
for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {
|
||||
fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
|
||||
}
|
||||
}
|
||||
|
||||
enum ggml_ftype ggml_parse_ftype(const char * str) {
|
||||
enum ggml_ftype ftype;
|
||||
if (str[0] == 'q') {
|
||||
const auto it = GGML_FTYPE_MAP.find(str);
|
||||
if (it == GGML_FTYPE_MAP.end()) {
|
||||
fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str);
|
||||
return GGML_FTYPE_UNKNOWN;
|
||||
}
|
||||
ftype = it->second;
|
||||
} else {
|
||||
ftype = (enum ggml_ftype) atoi(str);
|
||||
}
|
||||
|
||||
return ftype;
|
||||
}
|
||||
|
||||
bool ggml_common_quantize_0(
|
||||
std::ifstream & finp,
|
||||
std::ofstream & fout,
|
||||
const ggml_ftype ftype,
|
||||
const std::vector<std::string> & to_quant,
|
||||
const std::vector<std::string> & to_skip) {
|
||||
|
||||
ggml_type qtype = GGML_TYPE_F32;
|
||||
|
||||
switch (ftype) {
|
||||
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_2: qtype = GGML_TYPE_Q4_2; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
|
||||
case GGML_FTYPE_UNKNOWN:
|
||||
case GGML_FTYPE_ALL_F32:
|
||||
case GGML_FTYPE_MOSTLY_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
if (!ggml_is_quantized(qtype)) {
|
||||
fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, qtype, ggml_type_name(qtype));
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
size_t total_size_new = 0;
|
||||
|
||||
std::vector<float> work;
|
||||
|
||||
std::vector<uint8_t> data_u8;
|
||||
std::vector<ggml_fp16_t> data_f16;
|
||||
std::vector<float> data_f32;
|
||||
|
||||
std::vector<int64_t> hist_all(1 << 4, 0);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ttype;
|
||||
|
||||
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
finp.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
finp.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
|
||||
if (finp.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[4] = { 1, 1, 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
finp.read (&name[0], length);
|
||||
|
||||
printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype));
|
||||
|
||||
bool quantize = false;
|
||||
|
||||
// check if we should quantize this tensor
|
||||
for (const auto & s : to_quant) {
|
||||
if (std::regex_match(name, std::regex(s))) {
|
||||
quantize = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// check if we should skip this tensor
|
||||
for (const auto & s : to_skip) {
|
||||
if (std::regex_match(name, std::regex(s))) {
|
||||
quantize = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// quantize only 2D tensors
|
||||
quantize &= (n_dims == 2);
|
||||
|
||||
if (quantize) {
|
||||
if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {
|
||||
fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ttype == GGML_TYPE_F16) {
|
||||
data_f16.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
||||
data_f32.resize(nelements);
|
||||
for (int i = 0; i < nelements; ++i) {
|
||||
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
|
||||
}
|
||||
} else {
|
||||
data_f32.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
|
||||
}
|
||||
|
||||
ttype = qtype;
|
||||
} else {
|
||||
const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t);
|
||||
|
||||
data_u8.resize(nelements*bpe);
|
||||
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fout.write(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
}
|
||||
fout.write(&name[0], length);
|
||||
|
||||
if (quantize) {
|
||||
work.resize(nelements); // for quantization
|
||||
|
||||
size_t cur_size = 0;
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
switch ((ggml_type) ttype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_2:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
|
||||
total_size_new += cur_size;
|
||||
|
||||
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
|
||||
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||
hist_all[i] += hist_cur[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||
printf("%5.3f ", hist_cur[i] / (float)nelements);
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
|
||||
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
|
||||
total_size_new += data_u8.size();
|
||||
}
|
||||
|
||||
total_size_org += nelements * sizeof(float);
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
|
||||
printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype));
|
||||
|
||||
{
|
||||
int64_t sum_all = 0;
|
||||
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||
sum_all += hist_all[i];
|
||||
}
|
||||
|
||||
printf("%s: hist: ", __func__);
|
||||
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||
printf("%5.3f ", hist_all[i] / (float)sum_all);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
18
examples/common-ggml.h
Normal file
18
examples/common-ggml.h
Normal file
@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
enum ggml_ftype ggml_parse_ftype(const char * str);
|
||||
|
||||
void ggml_print_ftypes(FILE * fp = stderr);
|
||||
|
||||
bool ggml_common_quantize_0(
|
||||
std::ifstream & finp,
|
||||
std::ofstream & fout,
|
||||
const ggml_ftype ftype,
|
||||
const std::vector<std::string> & to_quant,
|
||||
const std::vector<std::string> & to_skip);
|
226
examples/common-sdl.cpp
Normal file
226
examples/common-sdl.cpp
Normal file
@ -0,0 +1,226 @@
|
||||
#include "common-sdl.h"
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool sdl_poll_events() {
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
return false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
50
examples/common-sdl.h
Normal file
50
examples/common-sdl.h
Normal file
@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
// Return false if need to quit
|
||||
bool sdl_poll_events();
|
505
examples/common.cpp
Normal file
505
examples/common.cpp
Normal file
@ -0,0 +1,505 @@
|
||||
#include "common.h"
|
||||
|
||||
// third-party utilities
|
||||
// use your favorite implementations
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.14159265358979323846
|
||||
#endif
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-s" || arg == "--seed") {
|
||||
params.seed = std::stoi(argv[++i]);
|
||||
} else if (arg == "-t" || arg == "--threads") {
|
||||
params.n_threads = std::stoi(argv[++i]);
|
||||
} else if (arg == "-p" || arg == "--prompt") {
|
||||
params.prompt = argv[++i];
|
||||
} else if (arg == "-n" || arg == "--n_predict") {
|
||||
params.n_predict = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_k") {
|
||||
params.top_k = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_p") {
|
||||
params.top_p = std::stof(argv[++i]);
|
||||
} else if (arg == "--temp") {
|
||||
params.temp = std::stof(argv[++i]);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
params.n_batch = std::stoi(argv[++i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
params.model = argv[++i];
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
gpt_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
gpt_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
|
||||
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
|
||||
fprintf(stderr, " prompt to start generation with (default: random)\n");
|
||||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||
const int r = rng() % 10;
|
||||
switch (r) {
|
||||
case 0: return "So";
|
||||
case 1: return "Once upon a time";
|
||||
case 2: return "When";
|
||||
case 3: return "The";
|
||||
case 4: return "After";
|
||||
case 5: return "If";
|
||||
case 6: return "import";
|
||||
case 7: return "He";
|
||||
case 8: return "She";
|
||||
case 9: return "They";
|
||||
default: return "To";
|
||||
}
|
||||
|
||||
return "The";
|
||||
}
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||
std::string result = s;
|
||||
size_t pos = 0;
|
||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||
result.replace(pos, from.length(), to);
|
||||
pos += to.length();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||
std::map<std::string, int32_t> result;
|
||||
|
||||
// read file into string
|
||||
std::string json;
|
||||
{
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
}
|
||||
|
||||
if (json[0] != '{') {
|
||||
return result;
|
||||
}
|
||||
|
||||
// parse json
|
||||
{
|
||||
bool has_key = false;
|
||||
bool in_token = false;
|
||||
|
||||
std::string str_key = "";
|
||||
std::string str_val = "";
|
||||
|
||||
int n = json.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
if (!in_token) {
|
||||
if (json[i] == ' ') continue;
|
||||
if (json[i] == '"') {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (json[i] == '\\' && i+1 < n) {
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
++i;
|
||||
} else if (json[i] == '"') {
|
||||
if (has_key == false) {
|
||||
has_key = true;
|
||||
++i;
|
||||
while (json[i] == ' ') ++i;
|
||||
++i; // :
|
||||
while (json[i] == ' ') ++i;
|
||||
if (json[i] != '\"') {
|
||||
while (json[i] != ',' && json[i] != '}') {
|
||||
str_val += json[i++];
|
||||
}
|
||||
has_key = false;
|
||||
} else {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
has_key = false;
|
||||
}
|
||||
|
||||
str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||
str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||
str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||
|
||||
try {
|
||||
result[str_key] = std::stoi(str_val);
|
||||
} catch (...) {
|
||||
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||
|
||||
}
|
||||
str_key = "";
|
||||
str_val = "";
|
||||
in_token = false;
|
||||
continue;
|
||||
}
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
vocab.token_to_id = ::json_parse(fname);
|
||||
|
||||
for (const auto & kv : vocab.token_to_id) {
|
||||
vocab.id_to_token[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||
|
||||
// print the vocabulary
|
||||
//for (auto kv : vocab.token_to_id) {
|
||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||
//}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
{
|
||||
const double scale = 1.0/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
double maxl = -INFINITY;
|
||||
for (const auto & kv : logits_id) {
|
||||
maxl = std::max(maxl, kv.first);
|
||||
}
|
||||
|
||||
// compute probs for the top K tokens
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
double sum = 0.0;
|
||||
for (const auto & kv : logits_id) {
|
||||
double p = exp(kv.first - maxl);
|
||||
probs.push_back(p);
|
||||
sum += p;
|
||||
}
|
||||
|
||||
// normalize the probs
|
||||
for (auto & p : probs) {
|
||||
p /= sum;
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += probs[i];
|
||||
if (cumsum >= top_p) {
|
||||
top_k = i + 1;
|
||||
probs.resize(top_k);
|
||||
logits_id.resize(top_k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cumsum = 1.0/cumsum;
|
||||
for (int i = 0; i < (int) probs.size(); i++) {
|
||||
probs[i] *= cumsum;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) probs.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true)
|
||||
{
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stereo && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (stereo) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
122
examples/common.h
Normal file
122
examples/common.h
Normal file
@ -0,0 +1,122 @@
|
||||
// Various helper functions and utilities
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
#define COMMON_SAMPLE_RATE 16000
|
||||
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
|
||||
struct gpt_params {
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 200; // new tokens to predict
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.9f;
|
||||
|
||||
int32_t n_batch = 8; // batch size for prompt processing
|
||||
|
||||
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
|
||||
std::string prompt;
|
||||
};
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||
|
||||
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::string trim(const std::string & s);
|
||||
|
||||
std::string replace(
|
||||
const std::string & s,
|
||||
const std::string & from,
|
||||
const std::string & to);
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
// poor-man's JSON parsing
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||
|
||||
// load the tokens from encoder.json
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||
|
||||
// sample next token given probabilities for each embedding
|
||||
//
|
||||
// - consider only the top K tokens
|
||||
// - from them, consider only the top tokens with cumulative probability > P
|
||||
//
|
||||
// TODO: not sure if this implementation is correct
|
||||
// TODO: temperature is not implemented
|
||||
//
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Audio utils
|
||||
//
|
||||
|
||||
// Read WAV audio file and store the PCM data into pcmf32
|
||||
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
|
||||
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
|
||||
bool read_wav(
|
||||
const std::string & fname,
|
||||
std::vector<float> & pcmf32,
|
||||
std::vector<std::vector<float>> & pcmf32s,
|
||||
bool stereo);
|
||||
|
||||
// Apply a high-pass frequency filter to PCM audio
|
||||
// Suppresses frequencies below cutoff Hz
|
||||
void high_pass_filter(
|
||||
std::vector<float> & data,
|
||||
float cutoff,
|
||||
float sample_rate);
|
||||
|
||||
// Basic voice activity detection (VAD) using audio energy adaptive threshold
|
||||
bool vad_simple(
|
||||
std::vector<float> & pcmf32,
|
||||
int sample_rate,
|
||||
int last_ms,
|
||||
float vad_thold,
|
||||
float freq_thold,
|
||||
bool verbose);
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
float similarity(const std::string & s0, const std::string & s1);
|
@ -8,7 +8,7 @@ function convertTypedArray(src, type) {
|
||||
|
||||
var printTextarea = (function() {
|
||||
var element = document.getElementById('output');
|
||||
if (element) element.alue = ''; // clear browser cache
|
||||
if (element) element.value = ''; // clear browser cache
|
||||
return function(text) {
|
||||
if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
|
||||
console.log(text);
|
||||
@ -88,11 +88,15 @@ async function fetchRemote(url, cbProgress, cbPrint) {
|
||||
// - check if the data is already in the IndexedDB
|
||||
// - if not, fetch it from the remote URL and store it in the IndexedDB
|
||||
function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
// query the storage quota and print it
|
||||
navigator.storage.estimate().then(function (estimate) {
|
||||
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
|
||||
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
|
||||
});
|
||||
if (!navigator.storage || !navigator.storage.estimate) {
|
||||
cbPrint('loadRemote: navigator.storage.estimate() is not supported');
|
||||
} else {
|
||||
// query the storage quota and print it
|
||||
navigator.storage.estimate().then(function (estimate) {
|
||||
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
|
||||
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
|
||||
});
|
||||
}
|
||||
|
||||
// check if the data is already in the IndexedDB
|
||||
var rq = indexedDB.open(dbName, dbVersion);
|
||||
@ -141,7 +145,15 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
var db = event.target.result;
|
||||
var tx = db.transaction(['models'], 'readwrite');
|
||||
var os = tx.objectStore('models');
|
||||
var rq = os.put(data, url);
|
||||
|
||||
var rq = null;
|
||||
try {
|
||||
var rq = os.put(data, url);
|
||||
} catch (e) {
|
||||
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
|
||||
cbCancel();
|
||||
return;
|
||||
}
|
||||
|
||||
rq.onsuccess = function (event) {
|
||||
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
|
||||
@ -176,7 +188,6 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
|
||||
|
||||
rq.onabort = function (event) {
|
||||
cbPrint('loadRemote: failed to open IndexedDB: abort');
|
||||
|
||||
cbCancel();
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ while [ $running -eq 1 ]; do
|
||||
err=$(cat /tmp/whisper-live.err | wc -l)
|
||||
done
|
||||
|
||||
./main -t 8 -m ./models/ggml-base.en.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
|
||||
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
|
||||
sleep 1
|
||||
|
@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -9,25 +9,36 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
```
|
||||
|
@ -1,9 +1,6 @@
|
||||
#include "whisper.h"
|
||||
#include "common.h"
|
||||
|
||||
// third-party utilities
|
||||
// use your favorite implementations
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
@ -11,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
||||
// Lowest is red, middle is yellow, highest is green.
|
||||
@ -53,27 +51,32 @@ void replace_all(std::string & s, const std::string & search, const std::string
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 5;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
float logprob_thold = -1.00f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool detect_language= false;
|
||||
bool diarize = false;
|
||||
bool split_on_word = false;
|
||||
bool no_fallback = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool output_jsn = false;
|
||||
bool output_lrc = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
@ -81,9 +84,11 @@ struct whisper_params {
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
std::vector<std::string> fname_out = {};
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -92,6 +97,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-"){
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg[0] != '-') {
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
@ -116,16 +126,23 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
@ -144,35 +161,42 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
||||
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -182,7 +206,7 @@ struct whisper_print_user_data {
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
@ -190,8 +214,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
int64_t t0 = 0;
|
||||
int64_t t1 = 0;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
@ -331,6 +355,37 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
char *escape_double_quotes_and_backslashes(const char *str) {
|
||||
if (str == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t escaped_length = strlen(str) + 1;
|
||||
|
||||
for (size_t i = 0; str[i] != '\0'; i++) {
|
||||
if (str[i] == '"' || str[i] == '\\') {
|
||||
escaped_length++;
|
||||
}
|
||||
}
|
||||
|
||||
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
|
||||
if (escaped == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t pos = 0;
|
||||
for (size_t i = 0; str[i] != '\0'; i++) {
|
||||
if (str[i] == '"' || str[i] == '\\') {
|
||||
escaped[pos++] = '\\';
|
||||
}
|
||||
escaped[pos++] = str[i];
|
||||
}
|
||||
|
||||
// no need to set zero due to calloc() being used prior
|
||||
|
||||
return escaped;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
@ -341,31 +396,160 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
fout << "start,end,text\n";
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
if (text[0] == ' ') {
|
||||
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
|
||||
}
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
char * text_escaped = escape_double_quotes_and_backslashes(text);
|
||||
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
|
||||
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
|
||||
std::ofstream fout(fname);
|
||||
int indent = 0;
|
||||
|
||||
auto doindent = [&]() {
|
||||
for (int i = 0; i < indent; i++) fout << "\t";
|
||||
};
|
||||
|
||||
auto start_arr = [&](const char *name) {
|
||||
doindent();
|
||||
fout << "\"" << name << "\": [\n";
|
||||
indent++;
|
||||
};
|
||||
|
||||
auto end_arr = [&](bool end = false) {
|
||||
indent--;
|
||||
doindent();
|
||||
fout << (end ? "]\n" : "},\n");
|
||||
};
|
||||
|
||||
auto start_obj = [&](const char *name = nullptr) {
|
||||
doindent();
|
||||
if (name) {
|
||||
fout << "\"" << name << "\": {\n";
|
||||
} else {
|
||||
fout << "{\n";
|
||||
}
|
||||
indent++;
|
||||
};
|
||||
|
||||
auto end_obj = [&](bool end = false) {
|
||||
indent--;
|
||||
doindent();
|
||||
fout << (end ? "}\n" : "},\n");
|
||||
};
|
||||
|
||||
auto start_value = [&](const char *name) {
|
||||
doindent();
|
||||
fout << "\"" << name << "\": ";
|
||||
};
|
||||
|
||||
auto value_s = [&](const char *name, const char *val, bool end = false) {
|
||||
start_value(name);
|
||||
char * val_escaped = escape_double_quotes_and_backslashes(val);
|
||||
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
|
||||
free(val_escaped);
|
||||
};
|
||||
|
||||
auto end_value = [&](bool end = false) {
|
||||
fout << (end ? "\n" : ",\n");
|
||||
};
|
||||
|
||||
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
|
||||
start_value(name);
|
||||
fout << val;
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
auto value_b = [&](const char *name, const bool val, bool end = false) {
|
||||
start_value(name);
|
||||
fout << (val ? "true" : "false");
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
start_obj();
|
||||
value_s("systeminfo", whisper_print_system_info());
|
||||
start_obj("model");
|
||||
value_s("type", whisper_model_type_readable(ctx));
|
||||
value_b("multilingual", whisper_is_multilingual(ctx));
|
||||
value_i("vocab", whisper_model_n_vocab(ctx));
|
||||
start_obj("audio");
|
||||
value_i("ctx", whisper_model_n_audio_ctx(ctx));
|
||||
value_i("state", whisper_model_n_audio_state(ctx));
|
||||
value_i("head", whisper_model_n_audio_head(ctx));
|
||||
value_i("layer", whisper_model_n_audio_layer(ctx), true);
|
||||
end_obj();
|
||||
start_obj("text");
|
||||
value_i("ctx", whisper_model_n_text_ctx(ctx));
|
||||
value_i("state", whisper_model_n_text_state(ctx));
|
||||
value_i("head", whisper_model_n_text_head(ctx));
|
||||
value_i("layer", whisper_model_n_text_layer(ctx), true);
|
||||
end_obj();
|
||||
value_i("mels", whisper_model_n_mels(ctx));
|
||||
value_i("ftype", whisper_model_ftype(ctx), true);
|
||||
end_obj();
|
||||
start_obj("params");
|
||||
value_s("model", params.model.c_str());
|
||||
value_s("language", params.language.c_str());
|
||||
value_b("translate", params.translate, true);
|
||||
end_obj();
|
||||
start_obj("result");
|
||||
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
|
||||
end_obj();
|
||||
start_arr("transcription");
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
start_obj();
|
||||
start_obj("timestamps");
|
||||
value_s("from", to_timestamp(t0, true).c_str());
|
||||
value_s("to", to_timestamp(t1, true).c_str(), true);
|
||||
end_obj();
|
||||
start_obj("offsets");
|
||||
value_i("from", t0 * 10);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj();
|
||||
value_s("text", text, true);
|
||||
end_obj(i == (n_segments - 1));
|
||||
}
|
||||
|
||||
end_arr(true);
|
||||
end_obj(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
|
||||
std::ofstream fout(fname);
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
// TODO: become parameter
|
||||
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
static const char * font = params.font_path.c_str();
|
||||
|
||||
std::ifstream fin(font);
|
||||
if (!fin.is_open()) {
|
||||
fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
|
||||
return false;
|
||||
}
|
||||
|
||||
fout << "#!/bin/bash" << "\n";
|
||||
fout << "\n";
|
||||
@ -468,6 +652,39 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_lrc(struct whisper_context * ctx, const char * fname) {
|
||||
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
fout << "[by:whisper.cpp]\n";
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t = whisper_full_get_segment_t0(ctx, i);
|
||||
|
||||
int64_t msec = t * 10;
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[16];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
|
||||
std::string timestamp_lrc = std::string(buf);
|
||||
|
||||
fout << '[' << timestamp_lrc << ']' << text << "\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -496,108 +713,16 @@ int main(int argc, char ** argv) {
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (!params.prompt.empty()) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||
fprintf(stderr, "initial tokens: [ ");
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||
}
|
||||
fprintf(stderr, "]\n");
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true)
|
||||
{
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return 5;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
|
||||
return 6;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return 8;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
||||
return 9;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
// print system information
|
||||
@ -617,6 +742,9 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
@ -639,6 +767,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.detect_language = params.detect_language;
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
@ -646,18 +775,19 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
wparams.temperature_inc = -1;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
@ -673,7 +803,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
@ -692,34 +822,45 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// output to text file
|
||||
if (params.output_txt) {
|
||||
const auto fname_txt = fname_inp + ".txt";
|
||||
const auto fname_txt = fname_out + ".txt";
|
||||
output_txt(ctx, fname_txt.c_str());
|
||||
}
|
||||
|
||||
// output to VTT file
|
||||
if (params.output_vtt) {
|
||||
const auto fname_vtt = fname_inp + ".vtt";
|
||||
const auto fname_vtt = fname_out + ".vtt";
|
||||
output_vtt(ctx, fname_vtt.c_str());
|
||||
}
|
||||
|
||||
// output to SRT file
|
||||
if (params.output_srt) {
|
||||
const auto fname_srt = fname_inp + ".srt";
|
||||
const auto fname_srt = fname_out + ".srt";
|
||||
output_srt(ctx, fname_srt.c_str(), params);
|
||||
}
|
||||
|
||||
// output to WTS file
|
||||
if (params.output_wts) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
const auto fname_wts = fname_out + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_inp + ".csv";
|
||||
const auto fname_csv = fname_out + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
// output to JSON file
|
||||
if (params.output_jsn) {
|
||||
const auto fname_jsn = fname_out + ".json";
|
||||
output_json(ctx, fname_jsn.c_str(), params);
|
||||
}
|
||||
|
||||
// output to LRC file
|
||||
if (params.output_lrc) {
|
||||
const auto fname_lrc = fname_out + ".lrc";
|
||||
output_lrc(ctx, fname_lrc.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
6
examples/quantize/CMakeLists.txt
Normal file
6
examples/quantize/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
set(TARGET quantize)
|
||||
add_executable(${TARGET} quantize.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
3
examples/quantize/README.md
Normal file
3
examples/quantize/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# quantize
|
||||
|
||||
Tool for integer quantization of Whisper `ggml` model files
|
215
examples/quantize/quantize.cpp
Normal file
215
examples/quantize/quantize.cpp
Normal file
@ -0,0 +1,215 @@
|
||||
#include "ggml.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
// default hparams (Whisper tiny)
|
||||
struct whisper_hparams {
|
||||
int32_t n_vocab = 51864;
|
||||
int32_t n_audio_ctx = 1500;
|
||||
int32_t n_audio_state = 384;
|
||||
int32_t n_audio_head = 6;
|
||||
int32_t n_audio_layer = 4;
|
||||
int32_t n_text_ctx = 448;
|
||||
int32_t n_text_state = 384;
|
||||
int32_t n_text_head = 6;
|
||||
int32_t n_text_layer = 4;
|
||||
int32_t n_mels = 80;
|
||||
int32_t f16 = 1;
|
||||
};
|
||||
|
||||
struct whisper_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
// quantize a model
|
||||
bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
|
||||
gpt_vocab vocab;
|
||||
|
||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||
|
||||
auto finp = std::ifstream(fname_inp, std::ios::binary);
|
||||
if (!finp) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto fout = std::ofstream(fname_out, std::ios::binary);
|
||||
if (!fout) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
}
|
||||
|
||||
whisper_hparams hparams;
|
||||
|
||||
// load hparams
|
||||
{
|
||||
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
finp.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
|
||||
finp.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
|
||||
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
|
||||
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
|
||||
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
finp.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
|
||||
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
||||
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
||||
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
||||
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
||||
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
|
||||
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
|
||||
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
||||
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
||||
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
||||
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
|
||||
|
||||
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
|
||||
fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
|
||||
fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
|
||||
fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
|
||||
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
fout.write((char *) &ftype, sizeof(hparams.f16));
|
||||
}
|
||||
|
||||
// load mel filters
|
||||
{
|
||||
whisper_filters filters;
|
||||
|
||||
finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel));
|
||||
fout.write((char *) &filters.n_mel, sizeof(filters.n_mel));
|
||||
finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft));
|
||||
fout.write((char *) &filters.n_fft, sizeof(filters.n_fft));
|
||||
|
||||
filters.data.resize(filters.n_mel * filters.n_fft);
|
||||
finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
||||
fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float));
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
finp.read ((char *) &n_vocab, sizeof(n_vocab));
|
||||
fout.write((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
//if (n_vocab != hparams.n_vocab) {
|
||||
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
// __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
|
||||
// return false;
|
||||
//}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
finp.read ((char *) &len, sizeof(len));
|
||||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
finp.read ((char *) word.data(), len);
|
||||
fout.write((char *) word.data(), len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// regexes of tensor names to not be quantized
|
||||
const std::vector<std::string> to_skip = {
|
||||
//"encoder.*",
|
||||
"encoder.conv1.bias",
|
||||
"encoder.conv2.bias",
|
||||
"encoder.positional_embedding",
|
||||
"decoder.positional_embedding",
|
||||
};
|
||||
|
||||
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
|
||||
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
finp.close();
|
||||
fout.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc != 4) {
|
||||
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// needed to initialize f16 tables
|
||||
{
|
||||
struct ggml_init_params params = { 0, NULL, false };
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
const std::string fname_inp = argv[1];
|
||||
const std::string fname_out = argv[2];
|
||||
|
||||
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
|
||||
|
||||
const int64_t t_main_start_us = ggml_time_us();
|
||||
|
||||
int64_t t_quantize_us = 0;
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
|
||||
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
t_quantize_us = ggml_time_us() - t_start_us;
|
||||
}
|
||||
|
||||
// report timing
|
||||
{
|
||||
const int64_t t_main_end_us = ggml_time_us();
|
||||
|
||||
printf("\n");
|
||||
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
|
||||
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -35,6 +35,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use, click the "Start" button and start speaking
|
||||
@ -45,6 +54,10 @@
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
@ -162,11 +175,17 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -177,6 +196,10 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -188,6 +211,10 @@
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -3,19 +3,16 @@
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
@ -46,6 +43,7 @@ struct whisper_params {
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool no_fallback = false;
|
||||
bool print_special = false;
|
||||
bool no_context = true;
|
||||
bool no_timestamps = false;
|
||||
@ -76,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
@ -97,325 +96,26 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -423,16 +123,17 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms);
|
||||
params.length_ms = std::max(params.length_ms, params.step_ms);
|
||||
|
||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
|
||||
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context |= use_vad;
|
||||
@ -450,7 +151,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@ -516,23 +217,7 @@ int main(int argc, char ** argv) {
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
@ -555,7 +240,7 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
SDL_Delay(1);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
|
||||
const int n_samples_new = pcmf32_new.size();
|
||||
@ -586,7 +271,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.get(2000, pcmf32_new);
|
||||
|
||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
audio.get(params.length_ms, pcmf32);
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
@ -606,7 +291,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = !use_vad;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
@ -616,7 +300,8 @@ int main(int argc, char ** argv) {
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
//wparams.temperature_inc = -1.0f;
|
||||
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
|
||||
|
||||
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
|
||||
@ -698,6 +383,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
|
1
examples/talk-llama/.gitignore
vendored
Normal file
1
examples/talk-llama/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
audio.mp3
|
16
examples/talk-llama/CMakeLists.txt
Normal file
16
examples/talk-llama/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
if (WHISPER_SDL2)
|
||||
# talk-llama
|
||||
set(TARGET talk-llama)
|
||||
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
endif ()
|
50
examples/talk-llama/README.md
Normal file
50
examples/talk-llama/README.md
Normal file
@ -0,0 +1,50 @@
|
||||
# talk-llama
|
||||
|
||||
Talk with an LLaMA AI in your terminal
|
||||
|
||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
|
||||
|
||||
## Building
|
||||
|
||||
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
# Build the "talk-llama" executable
|
||||
make talk-llama
|
||||
|
||||
# Run it
|
||||
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
||||
```
|
||||
|
||||
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
|
||||
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
|
||||
|
||||
## Session
|
||||
|
||||
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
|
||||
|
||||
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
|
||||
|
||||
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
|
||||
|
||||
Example usage:
|
||||
|
||||
```bash
|
||||
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
||||
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
|
||||
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
|
||||
By default, it is configured to use MacOS's `say`, but you can use whatever you wish.
|
||||
|
||||
## Discussion
|
||||
|
||||
If you have any feedback, please let "us" know in the following discussion: https://github.com/ggerganov/whisper.cpp/discussions/672?converting=1
|
23
examples/talk-llama/eleven-labs.py
Normal file
23
examples/talk-llama/eleven-labs.py
Normal file
@ -0,0 +1,23 @@
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
api_key = "" #Write your https://beta.elevenlabs.io api key here
|
||||
if not api_key:
|
||||
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk-llama/eleven-labs.py")
|
||||
sys.exit()
|
||||
|
||||
if importlib.util.find_spec("elevenlabs") is None:
|
||||
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
|
||||
sys.exit()
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
eleven = ElevenLabs(api_key)
|
||||
|
||||
# Get a Voice object, by name or UUID
|
||||
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
|
||||
|
||||
# Generate the TTS
|
||||
audio = voice.generate(str(sys.argv[2:]))
|
||||
|
||||
# Save the TTS to a file
|
||||
audio.save("audio")
|
433
examples/talk-llama/llama-util.h
Normal file
433
examples/talk-llama/llama-util.h
Normal file
@ -0,0 +1,433 @@
|
||||
// Internal header to be included only by llama.cpp.
|
||||
// Contains wrappers around OS interfaces.
|
||||
|
||||
#ifndef LLAMA_UTIL_H
|
||||
#define LLAMA_UTIL_H
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <climits>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <io.h>
|
||||
#include <stdio.h> // for _fseeki64
|
||||
#endif
|
||||
|
||||
#define LLAMA_ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||
abort(); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
__attribute__((format(gnu_printf, 1, 2)))
|
||||
#else
|
||||
__attribute__((format(printf, 1, 2)))
|
||||
#endif
|
||||
#endif
|
||||
static std::string format(const char * fmt, ...) {
|
||||
va_list ap, ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
LLAMA_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
struct llama_file {
|
||||
// use FILE * so we don't have to re-open the file to mmap
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
|
||||
llama_file(const char * fname, const char * mode) {
|
||||
fp = std::fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw format("failed to open %s: %s", fname, std::strerror(errno));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) {
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret == 0); // same
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t size) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, size, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw format("read error: %s", strerror(errno));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::string("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
|
||||
std::uint32_t read_u32() {
|
||||
std::uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string read_string(std::uint32_t len) {
|
||||
std::vector<char> chars(len);
|
||||
read_raw(chars.data(), len);
|
||||
return std::string(chars.data(), len);
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t size) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
size_t ret = std::fwrite(ptr, size, 1, fp);
|
||||
if (ret != 1) {
|
||||
throw format("write error: %s", strerror(errno));
|
||||
}
|
||||
}
|
||||
|
||||
void write_u32(std::uint32_t val) {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(_WIN32)
|
||||
static std::string llama_format_win_err(DWORD err) {
|
||||
LPSTR buf;
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
|
||||
if (!size) {
|
||||
return "FormatMessageA failed";
|
||||
}
|
||||
std::string ret(buf, size);
|
||||
LocalFree(buf);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct llama_mmap {
|
||||
void * addr;
|
||||
size_t size;
|
||||
|
||||
llama_mmap(const llama_mmap &) = delete;
|
||||
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
int fd = fileno(file->fp);
|
||||
int flags = MAP_SHARED;
|
||||
#ifdef __linux__
|
||||
flags |= MAP_POPULATE;
|
||||
#endif
|
||||
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
|
||||
if (addr == MAP_FAILED) {
|
||||
throw format("mmap failed: %s", strerror(errno));
|
||||
}
|
||||
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
if (madvise(addr, file->size, MADV_WILLNEED)) {
|
||||
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
munmap(addr, size);
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
DWORD error = GetLastError();
|
||||
|
||||
if (hMapping == NULL) {
|
||||
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
|
||||
}
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
error = GetLastError();
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
|
||||
}
|
||||
|
||||
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
WIN32_MEMORY_RANGE_ENTRY range;
|
||||
range.VirtualAddress = addr;
|
||||
range.NumberOfBytes = (SIZE_T)size;
|
||||
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
||||
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
|
||||
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
llama_mmap(struct llama_file *) {
|
||||
throw std::string("mmap not supported");
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Represents some region of memory being locked using mlock or VirtualLock;
|
||||
// will automatically unlock on destruction.
|
||||
struct llama_mlock {
|
||||
void * addr = NULL;
|
||||
size_t size = 0;
|
||||
bool failed_already = false;
|
||||
|
||||
llama_mlock() {}
|
||||
llama_mlock(const llama_mlock &) = delete;
|
||||
|
||||
~llama_mlock() {
|
||||
if (size) {
|
||||
raw_unlock(addr, size);
|
||||
}
|
||||
}
|
||||
|
||||
void init(void * addr) {
|
||||
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
|
||||
this->addr = addr;
|
||||
}
|
||||
|
||||
void grow_to(size_t target_size) {
|
||||
LLAMA_ASSERT(addr);
|
||||
if (failed_already) {
|
||||
return;
|
||||
}
|
||||
size_t granularity = lock_granularity();
|
||||
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
||||
if (target_size > size) {
|
||||
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
||||
size = target_size;
|
||||
} else {
|
||||
failed_already = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _POSIX_MEMLOCK_RANGE
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
return (size_t) sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
||||
#else
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
||||
#endif
|
||||
|
||||
bool raw_lock(const void * addr, size_t size) {
|
||||
if (!mlock(addr, size)) {
|
||||
return true;
|
||||
} else {
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
|
||||
// Check if the resource limit is fine after all
|
||||
struct rlimit lock_limit;
|
||||
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
|
||||
suggest = false;
|
||||
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
|
||||
suggest = false;
|
||||
|
||||
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
||||
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#undef MLOCK_SUGGESTION
|
||||
|
||||
void raw_unlock(void * addr, size_t size) {
|
||||
if (munlock(addr, size)) {
|
||||
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
SYSTEM_INFO si;
|
||||
GetSystemInfo(&si);
|
||||
return (size_t) si.dwPageSize;
|
||||
}
|
||||
|
||||
bool raw_lock(void * addr, size_t size) {
|
||||
for (int tries = 1; ; tries++) {
|
||||
if (VirtualLock(addr, size)) {
|
||||
return true;
|
||||
}
|
||||
if (tries == 2) {
|
||||
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
||||
size, this->size, llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// It failed but this was only the first try; increase the working
|
||||
// set size and try again.
|
||||
SIZE_T min_ws_size, max_ws_size;
|
||||
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
||||
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
// Per MSDN: "The maximum number of pages that a process can lock
|
||||
// is equal to the number of pages in its minimum working set minus
|
||||
// a small overhead."
|
||||
// Hopefully a megabyte is enough overhead:
|
||||
size_t increment = size + 1048576;
|
||||
// The minimum must be <= the maximum, so we need to increase both:
|
||||
min_ws_size += increment;
|
||||
max_ws_size += increment;
|
||||
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
||||
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void raw_unlock(void * addr, size_t size) {
|
||||
if (!VirtualUnlock(addr, size)) {
|
||||
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
void raw_lock(const void * addr, size_t size) {
|
||||
fprintf(stderr, "warning: mlock not supported on this system\n");
|
||||
}
|
||||
|
||||
void raw_unlock(const void * addr, size_t size) {}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
|
||||
struct llama_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
size_t size = 0;
|
||||
|
||||
void resize(size_t size) {
|
||||
delete[] addr;
|
||||
addr = new uint8_t[size];
|
||||
this->size = size;
|
||||
}
|
||||
|
||||
~llama_buffer() {
|
||||
delete[] addr;
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
struct llama_ctx_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
size_t size = 0;
|
||||
|
||||
void resize(size_t size) {
|
||||
if (addr) {
|
||||
ggml_cuda_host_free(addr);
|
||||
}
|
||||
addr = (uint8_t *) ggml_cuda_host_malloc(size);
|
||||
this->size = size;
|
||||
}
|
||||
|
||||
~llama_ctx_buffer() {
|
||||
if (addr) {
|
||||
ggml_cuda_host_free(addr);
|
||||
}
|
||||
}
|
||||
};
|
||||
#else
|
||||
typedef llama_buffer llama_ctx_buffer;
|
||||
#endif
|
||||
|
||||
#endif
|
2775
examples/talk-llama/llama.cpp
Normal file
2775
examples/talk-llama/llama.cpp
Normal file
File diff suppressed because it is too large
Load Diff
258
examples/talk-llama/llama.h
Normal file
258
examples/talk-llama/llama.h
Normal file
@ -0,0 +1,258 @@
|
||||
#ifndef LLAMA_H
|
||||
#define LLAMA_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef LLAMA_BUILD
|
||||
# define LLAMA_API __declspec(dllexport)
|
||||
# else
|
||||
# define LLAMA_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define LLAMA_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define LLAMA_API
|
||||
#endif
|
||||
|
||||
#define LLAMA_FILE_VERSION 1
|
||||
#define LLAMA_FILE_MAGIC 'ggjt'
|
||||
#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
|
||||
#define LLAMA_SESSION_MAGIC 'ggsn'
|
||||
#define LLAMA_SESSION_VERSION 0
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//
|
||||
// C interface
|
||||
//
|
||||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
struct llama_context;
|
||||
|
||||
typedef int llama_token;
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
float p; // probability of the token
|
||||
} llama_token_data;
|
||||
|
||||
typedef struct llama_token_data_array {
|
||||
llama_token_data * data;
|
||||
size_t size;
|
||||
bool sorted;
|
||||
} llama_token_data_array;
|
||||
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
struct llama_context_params {
|
||||
int n_ctx; // text context
|
||||
int n_parts; // -1 for default
|
||||
int seed; // RNG seed, 0 for random
|
||||
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
llama_progress_callback progress_callback;
|
||||
// context pointer passed to the progress callback
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
};
|
||||
|
||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||
|
||||
LLAMA_API bool llama_mmap_supported();
|
||||
LLAMA_API bool llama_mlock_supported();
|
||||
|
||||
// Various functions for loading a ggml llama model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
LLAMA_API struct llama_context * llama_init_from_file(
|
||||
const char * path_model,
|
||||
struct llama_context_params params);
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
// TODO: not great API - very likely to change
|
||||
// Returns 0 on success
|
||||
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
||||
LLAMA_API int llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
const char * fname_out,
|
||||
enum llama_ftype ftype,
|
||||
int nthread);
|
||||
|
||||
// Apply a LoRA adapter to a loaded model
|
||||
// path_base_model is the path to a higher quality model to use as a base for
|
||||
// the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_apply_lora_from_file(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora,
|
||||
const char * path_base_model,
|
||||
int n_threads);
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
|
||||
|
||||
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
||||
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
|
||||
|
||||
// Copies the state to the specified destination address.
|
||||
// Destination needs to have allocated enough memory.
|
||||
// Returns the number of bytes copied
|
||||
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
|
||||
|
||||
// Set the state reading from the specified address
|
||||
// Returns the number of bytes read
|
||||
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
|
||||
|
||||
// Save/load session file
|
||||
LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
|
||||
LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
|
||||
// Run the llama inference to obtain the logits and probabilities for the next token.
|
||||
// tokens + n_tokens is the provided batch of new tokens to process
|
||||
// n_past is the number of tokens to use from previous eval calls
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
const llama_token * tokens,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
// TODO: not sure if correct
|
||||
LLAMA_API int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
||||
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_embd (struct llama_context * ctx);
|
||||
|
||||
// Token logits obtained from the last call to llama_eval()
|
||||
// The logits for the last token are stored in the last row
|
||||
// Can be mutated in order to change the probabilities of the next token
|
||||
// Rows: n_tokens
|
||||
// Cols: n_vocab
|
||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the input
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||
|
||||
// Special tokens
|
||||
LLAMA_API llama_token llama_token_bos();
|
||||
LLAMA_API llama_token llama_token_eos();
|
||||
LLAMA_API llama_token llama_token_nl();
|
||||
|
||||
// Sampling functions
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
|
||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
|
||||
|
||||
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
|
||||
/// @details Selects the token with the highest probability.
|
||||
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
||||
#ifdef LLAMA_API_INTERNAL
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
struct ggml_tensor;
|
||||
|
||||
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||
|
||||
#endif
|
||||
|
||||
#endif // LLAMA_H
|
23
examples/talk-llama/prompts/talk-alpaca.txt
Normal file
23
examples/talk-llama/prompts/talk-alpaca.txt
Normal file
@ -0,0 +1,23 @@
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
|
||||
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
|
||||
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
|
||||
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
|
||||
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||
{1} responds with short and concise answers.
|
||||
|
||||
### Response:
|
||||
|
||||
{0}{4} Hello, {1}!
|
||||
{1}{4} Hello {0}! How may I help you today?
|
||||
{0}{4} What time is it?
|
||||
{1}{4} It is {2} o'clock.
|
||||
{0}{4} What year is it?
|
||||
{1}{4} We are in {3}.
|
||||
{0}{4} What is a cat?
|
||||
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||
{0}{4} Name a color.
|
||||
{1}{4} Blue
|
||||
{0}{4}
|
21
examples/talk-llama/speak.sh
Executable file
21
examples/talk-llama/speak.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Usage:
|
||||
# speak.sh <voice_id> <text-to-speak>
|
||||
|
||||
# espeak
|
||||
# Mac OS: brew install espeak
|
||||
# Linux: apt-get install espeak
|
||||
#
|
||||
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# for Mac
|
||||
say "$2"
|
||||
|
||||
# Eleven Labs
|
||||
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk-llama/eleven-labs.py
|
||||
#
|
||||
#wd=$(dirname $0)
|
||||
#script=$wd/eleven-labs.py
|
||||
#python3 $script $1 "$2" >/dev/null 2>&1
|
||||
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
|
673
examples/talk-llama/talk-llama.cpp
Normal file
673
examples/talk-llama/talk-llama.cpp
Normal file
@ -0,0 +1,673 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
|
||||
std::vector<llama_token> res(text.size() + (int)add_bos);
|
||||
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
|
||||
assert(n >= 0);
|
||||
res.resize(n);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t voice_ms = 10000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
int32_t n_parts_llama = -1;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
bool verbose_prompt = false;
|
||||
|
||||
std::string person = "Georgi";
|
||||
std::string language = "en";
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_llama = "models/ggml-llama-7B.bin";
|
||||
std::string speak = "./examples/talk-llama/speak.sh";
|
||||
std::string prompt = "";
|
||||
std::string fname_out;
|
||||
std::string path_session = ""; // path to file for saving/loading model eval state
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
|
||||
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
||||
else if (arg == "--session") { params.path_session = argv[++i];}
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
||||
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
|
||||
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
|
||||
else if (arg == "--prompt-file") {
|
||||
std::ifstream file(argv[++i]);
|
||||
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
|
||||
if (params.prompt.back() == '\n') {
|
||||
params.prompt.pop_back();
|
||||
}
|
||||
}
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
||||
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
|
||||
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
|
||||
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
|
||||
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
|
||||
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
|
||||
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string prompt_text,
|
||||
float & prob,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
|
||||
|
||||
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
|
||||
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
|
||||
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
|
||||
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||
{1} responds with short and concise answers.
|
||||
|
||||
{0}{4} Hello, {1}!
|
||||
{1}{4} Hello {0}! How may I help you today?
|
||||
{0}{4} What time is it?
|
||||
{1}{4} It is {2} o'clock.
|
||||
{0}{4} What year is it?
|
||||
{1}{4} We are in {3}.
|
||||
{0}{4} What is a cat?
|
||||
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||
{0}{4} Name a color.
|
||||
{1}{4} Blue
|
||||
{0}{4})";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
|
||||
|
||||
// llama init
|
||||
|
||||
auto lparams = llama_context_default_params();
|
||||
|
||||
// tune these to your liking
|
||||
lparams.n_ctx = 2048;
|
||||
lparams.seed = 1;
|
||||
lparams.f16_kv = true;
|
||||
lparams.n_parts = params.n_parts_llama;
|
||||
|
||||
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
if (!whisper_is_multilingual(ctx_wsp)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
|
||||
const std::string chat_symb = ":";
|
||||
const std::string bot_name = "LLaMA";
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
|
||||
|
||||
// construct the initial prompt for LLaMA inference
|
||||
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
|
||||
|
||||
// need to have leading ' '
|
||||
prompt_llama.insert(0, 1, ' ');
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
|
||||
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
|
||||
|
||||
{
|
||||
// get time string
|
||||
std::string time_str;
|
||||
{
|
||||
time_t t = time(0);
|
||||
struct tm * now = localtime(&t);
|
||||
char buf[128];
|
||||
strftime(buf, sizeof(buf), "%H:%M", now);
|
||||
time_str = buf;
|
||||
}
|
||||
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
|
||||
}
|
||||
|
||||
{
|
||||
// get year string
|
||||
std::string year_str;
|
||||
{
|
||||
time_t t = time(0);
|
||||
struct tm * now = localtime(&t);
|
||||
char buf[128];
|
||||
strftime(buf, sizeof(buf), "%Y", now);
|
||||
year_str = buf;
|
||||
}
|
||||
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
|
||||
}
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
||||
|
||||
// init session
|
||||
std::string path_session = params.path_session;
|
||||
std::vector<llama_token> session_tokens;
|
||||
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
||||
|
||||
if (!path_session.empty()) {
|
||||
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
||||
|
||||
// fopen to check for existing session
|
||||
FILE * fp = std::fopen(path_session.c_str(), "rb");
|
||||
if (fp != NULL) {
|
||||
std::fclose(fp);
|
||||
|
||||
session_tokens.resize(lparams.n_ctx);
|
||||
size_t n_token_count_out = 0;
|
||||
if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
||||
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
||||
return 1;
|
||||
}
|
||||
session_tokens.resize(n_token_count_out);
|
||||
for (size_t i = 0; i < session_tokens.size(); i++) {
|
||||
embd_inp[i] = session_tokens[i];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
||||
} else {
|
||||
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// evaluate the initial prompt
|
||||
|
||||
printf("\n");
|
||||
printf("%s : initializing - please wait ...\n", __func__);
|
||||
|
||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s", prompt_llama.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
// debug message about similarity of saved session, if applicable
|
||||
size_t n_matching_session_tokens = 0;
|
||||
if (session_tokens.size()) {
|
||||
for (llama_token id : session_tokens) {
|
||||
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
|
||||
break;
|
||||
}
|
||||
n_matching_session_tokens++;
|
||||
}
|
||||
if (n_matching_session_tokens >= embd_inp.size()) {
|
||||
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
|
||||
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
||||
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
} else {
|
||||
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
}
|
||||
}
|
||||
|
||||
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
|
||||
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
|
||||
// initial prompt so it doesn't need to be an exact match.
|
||||
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
|
||||
|
||||
printf("%s : done! start speaking in the microphone\n", __func__);
|
||||
printf("\n");
|
||||
printf("%s%s", params.person.c_str(), chat_symb.c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// clear audio buffer
|
||||
audio.clear();
|
||||
|
||||
// text inference variables
|
||||
const int voice_id = 2;
|
||||
const int n_keep = embd_inp.size();
|
||||
const int n_ctx = llama_n_ctx(ctx_llama);
|
||||
|
||||
int n_past = n_keep;
|
||||
int n_prev = 64; // TODO arg
|
||||
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
// reverse prompts for detecting when it's time to stop speaking
|
||||
std::vector<std::string> antiprompts = {
|
||||
params.person + chat_symb,
|
||||
};
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\(.*?\\)");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
|
||||
|
||||
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
|
||||
|
||||
if (text_heard.empty() || tokens.empty() || force_speak) {
|
||||
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
|
||||
audio.clear();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
text_heard.insert(0, 1, ' ');
|
||||
text_heard += "\n" + bot_name + chat_symb;
|
||||
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
|
||||
fflush(stdout);
|
||||
|
||||
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
||||
|
||||
// Append the new input tokens to the session_tokens vector
|
||||
if (!path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
|
||||
}
|
||||
|
||||
// text inference
|
||||
bool done = false;
|
||||
std::string text_to_speak;
|
||||
while (true) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (n_past + (int) embd.size() > n_ctx) {
|
||||
n_past = n_keep;
|
||||
|
||||
// insert n_left/2 tokens at the start of embd from last_n_tokens
|
||||
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
|
||||
// stop saving session if we run out of context
|
||||
path_session = "";
|
||||
//printf("\n---\n");
|
||||
//printf("resetting: '");
|
||||
//for (int i = 0; i < (int) embd.size(); i++) {
|
||||
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
|
||||
//}
|
||||
//printf("'\n");
|
||||
//printf("\n---\n");
|
||||
}
|
||||
|
||||
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
||||
// REVIEW
|
||||
if (n_session_consumed < (int) session_tokens.size()) {
|
||||
size_t i = 0;
|
||||
for ( ; i < embd.size(); i++) {
|
||||
if (embd[i] != session_tokens[n_session_consumed]) {
|
||||
session_tokens.resize(n_session_consumed);
|
||||
break;
|
||||
}
|
||||
|
||||
n_past++;
|
||||
n_session_consumed++;
|
||||
|
||||
if (n_session_consumed >= (int) session_tokens.size()) {
|
||||
i++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (i > 0) {
|
||||
embd.erase(embd.begin(), embd.begin() + i);
|
||||
}
|
||||
}
|
||||
|
||||
if (embd.size() > 0 && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
|
||||
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
||||
n_past += embd.size();
|
||||
|
||||
embd.clear();
|
||||
|
||||
if (done) break;
|
||||
|
||||
{
|
||||
// out of user input, sample next token
|
||||
const float top_k = 5;
|
||||
const float top_p = 0.80f;
|
||||
const float temp = 0.30f;
|
||||
const float repeat_penalty = 1.1764f;
|
||||
|
||||
const int repeat_last_n = 256;
|
||||
|
||||
if (!path_session.empty() && need_to_save_session) {
|
||||
need_to_save_session = false;
|
||||
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
}
|
||||
|
||||
llama_token id = 0;
|
||||
|
||||
{
|
||||
auto logits = llama_get_logits(ctx_llama);
|
||||
auto n_vocab = llama_n_vocab(ctx_llama);
|
||||
|
||||
logits[llama_token_eos()] = 0;
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// apply repeat penalty
|
||||
const float nl_logit = logits[llama_token_nl()];
|
||||
|
||||
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
|
||||
embd_inp.data() + std::max(0, n_past - repeat_last_n),
|
||||
repeat_last_n, repeat_penalty);
|
||||
|
||||
logits[llama_token_nl()] = nl_logit;
|
||||
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(ctx_llama, &candidates_p, top_k);
|
||||
llama_sample_top_p(ctx_llama, &candidates_p, top_p);
|
||||
llama_sample_temperature(ctx_llama, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx_llama, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
||||
if (id != llama_token_eos()) {
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
|
||||
text_to_speak += llama_token_to_str(ctx_llama, id);
|
||||
|
||||
printf("%s", llama_token_to_str(ctx_llama, id));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::string last_output;
|
||||
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
|
||||
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
|
||||
}
|
||||
last_output += llama_token_to_str(ctx_llama, embd[0]);
|
||||
|
||||
for (std::string & antiprompt : antiprompts) {
|
||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||
done = true;
|
||||
text_to_speak = ::replace(text_to_speak, antiprompt, "");
|
||||
fflush(stdout);
|
||||
need_to_save_session = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
|
||||
audio.clear();
|
||||
|
||||
++n_iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx_wsp);
|
||||
whisper_free(ctx_wsp);
|
||||
|
||||
llama_print_timings(ctx_llama);
|
||||
llama_free(ctx_llama);
|
||||
|
||||
return 0;
|
||||
}
|
@ -13,6 +13,7 @@ include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
common
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
@ -1,4 +1,6 @@
|
||||
#include "ggml.h"
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
@ -14,150 +16,6 @@
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.push_back(std::make_pair(logits[i], i));
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
@ -165,7 +23,7 @@ struct gpt2_hparams {
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
int32_t ftype = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
@ -187,7 +45,7 @@ struct gpt2_layer {
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_w;
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
@ -198,8 +56,9 @@ struct gpt2_model {
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * lm_head; // language model head
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
@ -241,14 +100,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
printf("%s: ftype = %d\n", __func__, hparams.ftype);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
@ -275,9 +134,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
|
||||
if (wtype == GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
|
||||
__func__, fname.c_str(), model.hparams.ftype);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
@ -291,32 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
@ -325,9 +190,11 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
@ -350,36 +217,38 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/lm_head"] = model.lm_head;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
@ -397,7 +266,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
@ -425,14 +294,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
bool has_lm_head = false;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
int32_t ttype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
@ -461,13 +332,18 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
// for debugging
|
||||
if (0) {
|
||||
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
@ -475,7 +351,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
// GPT-2 models share the WTE tensor as the LM head
|
||||
if (name == "model/wte" && has_lm_head == false) {
|
||||
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (name == "model/lm_head") {
|
||||
has_lm_head = true;
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
@ -493,7 +377,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
// - embd_w: the predicted logits for the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
@ -512,12 +396,12 @@ bool gpt2_eval(
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 640u*1024*1024;
|
||||
static size_t buf_size = 512u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
@ -528,13 +412,14 @@ bool gpt2_eval(
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ buf,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
struct ggml_cgraph gf = {};
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
@ -578,7 +463,7 @@ bool gpt2_eval(
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
model.layers[il].c_attn_attn_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -654,11 +539,13 @@ bool gpt2_eval(
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
@ -685,7 +572,7 @@ bool gpt2_eval(
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
model.layers[il].c_attn_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -722,7 +609,7 @@ bool gpt2_eval(
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
model.layers[il].c_mlp_fc_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -742,7 +629,7 @@ bool gpt2_eval(
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
model.layers[il].c_mlp_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -769,12 +656,12 @@ bool gpt2_eval(
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, 50257] - model.lm_head
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
//inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
@ -788,7 +675,7 @@ bool gpt2_eval(
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
// return result just for the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
@ -825,7 +712,7 @@ Me too.
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
int32_t top_k = 5;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
@ -833,14 +720,15 @@ Me too.
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(NULL));
|
||||
ctx->rng = std::mt19937(time(nullptr));
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -884,9 +772,9 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
|
||||
std::string result;
|
||||
|
||||
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
|
||||
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (!embd.empty()) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
@ -913,10 +801,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256 ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "." ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "!" ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "?") {
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -2,18 +2,12 @@
|
||||
|
||||
// TODO: Change to C-style API and move to ./examples for easy reuse.
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
struct gpt2_context;
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model);
|
||||
|
@ -44,6 +44,15 @@
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the models you would like to use and click the "Start" button to begin the conversation
|
||||
@ -54,6 +63,10 @@
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<!--
|
||||
@ -266,11 +279,17 @@
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -281,6 +300,10 @@
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
@ -292,6 +315,10 @@
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
|
2
examples/talk/.gitignore
vendored
2
examples/talk/.gitignore
vendored
@ -1 +1 @@
|
||||
eleven-labs.py
|
||||
audio.mp3
|
||||
|
@ -1,16 +1,8 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
if (WHISPER_SDL2)
|
||||
# talk
|
||||
set(TARGET talk)
|
||||
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
|
||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||
|
||||
```
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
23
examples/talk/eleven-labs.py
Normal file
23
examples/talk/eleven-labs.py
Normal file
@ -0,0 +1,23 @@
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
api_key = "" #Write your https://beta.elevenlabs.io api key here
|
||||
if not api_key:
|
||||
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk/eleven-labs.py")
|
||||
sys.exit()
|
||||
|
||||
if importlib.util.find_spec("elevenlabs") is None:
|
||||
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
|
||||
sys.exit()
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
eleven = ElevenLabs(api_key)
|
||||
|
||||
# Get a Voice object, by name or UUID
|
||||
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
|
||||
|
||||
# Generate the TTS
|
||||
audio = voice.generate(str(sys.argv[2:]))
|
||||
|
||||
# Save the TTS to a file
|
||||
audio.save("audio")
|
@ -1,4 +1,6 @@
|
||||
#include "ggml.h"
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
@ -14,150 +16,6 @@
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.empty()) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double /*temp*/,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
@ -165,7 +23,7 @@ struct gpt2_hparams {
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
int32_t ftype = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
@ -187,7 +45,7 @@ struct gpt2_layer {
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_w;
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
@ -198,8 +56,9 @@ struct gpt2_model {
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
struct ggml_tensor * lm_head; // language model head
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
@ -241,14 +100,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
printf("%s: ftype = %d\n", __func__, hparams.ftype);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
@ -268,16 +127,21 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) &word[0], len);
|
||||
fin.read((char *) word.data(), len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
|
||||
if (wtype == GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
|
||||
__func__, fname.c_str(), model.hparams.ftype);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
@ -291,32 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
@ -325,9 +190,11 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = nullptr;
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
@ -350,36 +217,38 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
model.tensors["model/lm_head"] = model.lm_head;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
@ -397,7 +266,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
@ -425,14 +294,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
bool has_lm_head = false;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
int32_t ttype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
@ -448,7 +319,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
@ -461,13 +332,18 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
// for debugging
|
||||
if (0) {
|
||||
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
@ -475,7 +351,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
// GPT-2 models share the WTE tensor as the LM head
|
||||
if (name == "model/wte" && has_lm_head == false) {
|
||||
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
if (name == "model/lm_head") {
|
||||
has_lm_head = true;
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
@ -493,7 +377,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
// - embd_w: the predicted logits for the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
@ -512,12 +396,12 @@ bool gpt2_eval(
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 5640ull*1024*1024;
|
||||
static size_t buf_size = 512u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
@ -528,13 +412,14 @@ bool gpt2_eval(
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ buf,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
struct ggml_cgraph gf = {};
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
@ -578,7 +463,7 @@ bool gpt2_eval(
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
model.layers[il].c_attn_attn_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -654,11 +539,13 @@ bool gpt2_eval(
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
@ -685,7 +572,7 @@ bool gpt2_eval(
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
model.layers[il].c_attn_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -722,7 +609,7 @@ bool gpt2_eval(
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
model.layers[il].c_mlp_fc_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -742,7 +629,7 @@ bool gpt2_eval(
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
model.layers[il].c_mlp_proj_w,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
@ -769,12 +656,12 @@ bool gpt2_eval(
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, 50257] - model.lm_head
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
//inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
@ -788,7 +675,7 @@ bool gpt2_eval(
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
// return result just for the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
|
@ -2,18 +2,12 @@
|
||||
|
||||
// TODO: Change to C-style API and move to ./examples for easy reuse.
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
struct gpt2_context;
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model);
|
||||
|
@ -7,9 +7,13 @@
|
||||
# Mac OS: brew install espeak
|
||||
# Linux: apt-get install espeak
|
||||
#
|
||||
espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
#espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# Mac OS "say" command
|
||||
say "$2"
|
||||
|
||||
# Eleven Labs
|
||||
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk/eleven-labs.py
|
||||
#
|
||||
#wd=$(dirname $0)
|
||||
#script=$wd/eleven-labs.py
|
||||
|
@ -1,16 +1,14 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
@ -105,320 +103,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||
std::string result = s;
|
||||
size_t pos = 0;
|
||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||
result.replace(pos, from.length(), to);
|
||||
pos += to.length();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
@ -557,22 +241,10 @@ int main(int argc, char ** argv) {
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
|
||||
// delay
|
||||
@ -583,7 +255,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
@ -9,4 +9,6 @@ To use:
|
||||
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
|
||||
(PS: Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.)
|
||||
|
||||
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg">
|
||||
|
@ -2,6 +2,7 @@ package com.whispercppdemo.ui.main
|
||||
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.text.selection.SelectionContainer
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.Composable
|
||||
@ -19,6 +20,7 @@ fun MainScreen(viewModel: MainScreenViewModel) {
|
||||
canTranscribe = viewModel.canTranscribe,
|
||||
isRecording = viewModel.isRecording,
|
||||
messageLog = viewModel.dataLog,
|
||||
onBenchmarkTapped = viewModel::benchmark,
|
||||
onTranscribeSampleTapped = viewModel::transcribeSample,
|
||||
onRecordTapped = viewModel::toggleRecord
|
||||
)
|
||||
@ -30,6 +32,7 @@ private fun MainScreen(
|
||||
canTranscribe: Boolean,
|
||||
isRecording: Boolean,
|
||||
messageLog: String,
|
||||
onBenchmarkTapped: () -> Unit,
|
||||
onTranscribeSampleTapped: () -> Unit,
|
||||
onRecordTapped: () -> Unit
|
||||
) {
|
||||
@ -45,8 +48,11 @@ private fun MainScreen(
|
||||
.padding(innerPadding)
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Row(horizontalArrangement = Arrangement.SpaceBetween) {
|
||||
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
|
||||
Column(verticalArrangement = Arrangement.SpaceBetween) {
|
||||
Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) {
|
||||
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
|
||||
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
|
||||
}
|
||||
RecordButton(
|
||||
enabled = canTranscribe,
|
||||
isRecording = isRecording,
|
||||
@ -60,7 +66,16 @@ private fun MainScreen(
|
||||
|
||||
@Composable
|
||||
private fun MessageLog(log: String) {
|
||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||
SelectionContainer() {
|
||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
|
||||
Button(onClick = onClick, enabled = enabled) {
|
||||
Text("Benchmark")
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
@ -41,10 +41,15 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
printSystemInfo()
|
||||
loadData()
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun printSystemInfo() {
|
||||
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
|
||||
}
|
||||
|
||||
private suspend fun loadData() {
|
||||
printMessage("Loading data...\n")
|
||||
try {
|
||||
@ -81,10 +86,29 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
|
||||
}
|
||||
|
||||
fun benchmark() = viewModelScope.launch {
|
||||
runBenchmark(6)
|
||||
}
|
||||
|
||||
fun transcribeSample() = viewModelScope.launch {
|
||||
transcribeAudio(getFirstSample())
|
||||
}
|
||||
|
||||
private suspend fun runBenchmark(nthreads: Int) {
|
||||
if (!canTranscribe) {
|
||||
return
|
||||
}
|
||||
|
||||
canTranscribe = false
|
||||
|
||||
printMessage("Running benchmark. This will take minutes...\n")
|
||||
whisperContext?.benchMemory(nthreads)?.let{ printMessage(it) }
|
||||
printMessage("\n")
|
||||
whisperContext?.benchGgmlMulMat(nthreads)?.let{ printMessage(it) }
|
||||
|
||||
canTranscribe = true
|
||||
}
|
||||
|
||||
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
|
||||
samplesPath.listFiles()!!.first()
|
||||
}
|
||||
@ -114,11 +138,14 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
canTranscribe = false
|
||||
|
||||
try {
|
||||
printMessage("Reading wave samples...\n")
|
||||
printMessage("Reading wave samples... ")
|
||||
val data = readAudioSamples(file)
|
||||
printMessage("${data.size / (16000 / 1000)} ms\n")
|
||||
printMessage("Transcribing data...\n")
|
||||
val start = System.currentTimeMillis()
|
||||
val text = whisperContext?.transcribeData(data)
|
||||
printMessage("Done: $text\n")
|
||||
val elapsed = System.currentTimeMillis() - start
|
||||
printMessage("Done ($elapsed ms): $text\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
|
@ -27,6 +27,14 @@ class WhisperContext private constructor(private var ptr: Long) {
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
|
||||
return@withContext WhisperLib.benchMemcpy(nthreads)
|
||||
}
|
||||
|
||||
suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
|
||||
return@withContext WhisperLib.benchGgmlMulMat(nthreads)
|
||||
}
|
||||
|
||||
suspend fun release() = withContext(scope.coroutineContext) {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr)
|
||||
@ -66,6 +74,10 @@ class WhisperContext private constructor(private var ptr: Long) {
|
||||
}
|
||||
return WhisperContext(ptr)
|
||||
}
|
||||
|
||||
fun getSystemInfo(): String {
|
||||
return WhisperLib.getSystemInfo()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,6 +86,7 @@ private class WhisperLib {
|
||||
init {
|
||||
Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}")
|
||||
var loadVfpv4 = false
|
||||
var loadV8fp16 = false
|
||||
if (isArmEabiV7a()) {
|
||||
// armeabi-v7a needs runtime detection support
|
||||
val cpuInfo = cpuInfo()
|
||||
@ -84,11 +97,24 @@ private class WhisperLib {
|
||||
loadVfpv4 = true
|
||||
}
|
||||
}
|
||||
} else if (isArmEabiV8a()) {
|
||||
// ARMv8.2a needs runtime detection support
|
||||
val cpuInfo = cpuInfo()
|
||||
cpuInfo?.let {
|
||||
Log.d(LOG_TAG, "CPU info: $cpuInfo")
|
||||
if (cpuInfo.contains("fphp")) {
|
||||
Log.d(LOG_TAG, "CPU supports fp16 arithmetic")
|
||||
loadV8fp16 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (loadVfpv4) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so")
|
||||
System.loadLibrary("whisper_vfpv4")
|
||||
} else if (loadV8fp16) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so")
|
||||
System.loadLibrary("whisper_v8fp16_va")
|
||||
} else {
|
||||
Log.d(LOG_TAG, "Loading libwhisper.so")
|
||||
System.loadLibrary("whisper")
|
||||
@ -103,6 +129,9 @@ private class WhisperLib {
|
||||
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
|
||||
external fun getTextSegmentCount(contextPtr: Long): Int
|
||||
external fun getTextSegment(contextPtr: Long, index: Int): String
|
||||
external fun getSystemInfo(): String
|
||||
external fun benchMemcpy(nthread: Int): String
|
||||
external fun benchGgmlMulMat(nthread: Int): String
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,6 +139,10 @@ private fun isArmEabiV7a(): Boolean {
|
||||
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
|
||||
}
|
||||
|
||||
private fun isArmEabiV8a(): Boolean {
|
||||
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a")
|
||||
}
|
||||
|
||||
private fun cpuInfo(): String? {
|
||||
return try {
|
||||
File("/proc/cpuinfo").inputStream().bufferedReader().use {
|
||||
|
@ -12,4 +12,15 @@ ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
|
||||
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
|
||||
LOCAL_CFLAGS += -mfpu=neon-vfpv4
|
||||
include $(BUILD_SHARED_LIBRARY)
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(TARGET_ARCH_ABI),arm64-v8a)
|
||||
include $(CLEAR_VARS)
|
||||
LOCAL_MODULE := libwhisper_v8fp16_va
|
||||
include $(LOCAL_PATH)/Whisper.mk
|
||||
# Allow building NEON FMA code.
|
||||
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
|
||||
LOCAL_CFLAGS += -march=armv8.2-a+fp16
|
||||
include $(BUILD_SHARED_LIBRARY)
|
||||
endif
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <sys/sysinfo.h>
|
||||
#include <string.h>
|
||||
#include "whisper.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
@ -213,4 +214,30 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
|
||||
JNIEnv *env, jobject thiz
|
||||
) {
|
||||
UNUSED(thiz);
|
||||
const char *sysinfo = whisper_print_system_info();
|
||||
jstring string = (*env)->NewStringUTF(env, sysinfo);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
|
||||
}
|
||||
|
@ -24,3 +24,5 @@ Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag in Build Pha
|
||||
This can significantly improve the performance of the transcription:
|
||||
|
||||
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
|
||||
|
||||
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
|
||||
|
@ -296,6 +296,10 @@
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
MTL_FAST_MATH = YES;
|
||||
OTHER_CFLAGS = (
|
||||
"-O3",
|
||||
"-DNDEBUG",
|
||||
);
|
||||
SDKROOT = iphoneos;
|
||||
VALIDATE_PRODUCT = YES;
|
||||
};
|
||||
|
@ -1,14 +1,18 @@
|
||||
A sample SwiftUI app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
See also: [whisper.objc](https://github.com/ggerganov/whisper.cpp/tree/master/examples/whisper.objc).
|
||||
|
||||
To use:
|
||||
**Usage**:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Add the model to "whisper.swiftui.demo/Resources/models" via Xcode.
|
||||
2. Add the model to `whisper.swiftui.demo/Resources/models` **via Xcode**.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Add the model to "whisper.swiftui.demo/Resources/samples" via Xcode.
|
||||
5. Select the "release" build configuration under "Run", then deploy and run to your device.
|
||||
4. Add the sample audio file to `whisper.swiftui.demo/Resources/samples` **via Xcode**.
|
||||
5. Select the "Release" [^2] build configuration under "Run", then deploy and run to your device.
|
||||
|
||||
**Note:** Pay attention to the folder path: `whisper.swiftui.demo/Resources/models` is the appropriate directory to place resources whilst `whisper.swiftui.demo/Models` is related to actual code.
|
||||
|
||||
[^1]: I recommend the tiny, base or small models for running on an iOS device.
|
||||
|
||||
[^2]: The `Release` build can boost performance of transcription. In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
|
||||
|
||||

|
||||
|
@ -430,6 +430,10 @@
|
||||
LLVM_LTO = YES;
|
||||
MACOSX_DEPLOYMENT_TARGET = 13.0;
|
||||
MARKETING_VERSION = 1.0;
|
||||
OTHER_CFLAGS = (
|
||||
"-O3",
|
||||
"-DNDEBUG",
|
||||
);
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.whispercppdemo.WhisperCppDemo;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
|
@ -31,9 +31,9 @@ endif()
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s PTHREAD_POOL_SIZE_STRICT=0 \
|
||||
-s INITIAL_MEMORY=2000MB \
|
||||
-s TOTAL_MEMORY=2000MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -37,6 +37,6 @@ emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/whisper.wasm/* /path/to/html/
|
||||
cp bin/libwhisper.worker.js /path/to/html/
|
||||
cp bin/whisper.wasm/* /path/to/html/
|
||||
cp bin/libmain.worker.js /path/to/html/
|
||||
```
|
||||
|
@ -10,6 +10,12 @@ std::thread g_worker;
|
||||
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
static inline int mpow2(int n) {
|
||||
int p = 1;
|
||||
while (p <= n) p *= 2;
|
||||
return p/2;
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(whisper) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
if (g_worker.joinable()) {
|
||||
@ -43,7 +49,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, bool translate) {
|
||||
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, int nthreads, bool translate) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
@ -66,7 +72,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
|
||||
params.print_special = false;
|
||||
params.translate = translate;
|
||||
params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
|
||||
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
|
||||
params.n_threads = std::min(nthreads, std::min(16, mpow2(std::thread::hardware_concurrency())));
|
||||
params.offset_ms = 0;
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
|
@ -40,19 +40,42 @@
|
||||
|
||||
Note that the computation is quite heavy and may take a few seconds to complete.<br>
|
||||
The transcription results will be displayed in the text area below.<br><br>
|
||||
<b>Important: your browser must support WASM SIMD instructions for this to work.</b>
|
||||
<b>Important:</b>
|
||||
<ul>
|
||||
<li>your browser must support WASM SIMD instructions for this to work</li>
|
||||
<li>Firefox cannot load files larger than 256 MB - use Chrome instead</li>
|
||||
</ul>
|
||||
|
||||
<br><br><hr>
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<hr>
|
||||
|
||||
<div id="model">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
Whisper models: <span id="model-whisper-status"></span><br><br>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
|
||||
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
|
||||
<button id="fetch-whisper-small" onclick="loadWhisper('small')">small (466 MB)</button>
|
||||
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
<br><br>
|
||||
Quantized models:<br><br>
|
||||
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-tiny-q5_1" onclick="loadWhisper('tiny-q5_1')">tiny (Q5_1, 31 MB)</button>
|
||||
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
|
||||
<button id="fetch-whisper-base-q5_1" onclick="loadWhisper('base-q5_1')">base (Q5_1, 57 MB)</button>
|
||||
<button id="fetch-whisper-small-en-q5_1" onclick="loadWhisper('small-en-q5_1')">small.en (Q5_1, 182 MB)</button>
|
||||
<button id="fetch-whisper-small-q5_1" onclick="loadWhisper('small-q5_1')">small (Q5_1, 182 MB)</button><br>
|
||||
<button id="fetch-whisper-medium-en-q5_0" onclick="loadWhisper('medium-en-q5_0')">medium.en (Q5_0, 515 MB)</button>
|
||||
<button id="fetch-whisper-medium-q5_0" onclick="loadWhisper('medium-q5_0')">medium (Q5_0, 515 MB)</button>
|
||||
<button id="fetch-whisper-large-q5_0" onclick="loadWhisper('large-q5_0')">large (Q5_0, 1030 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@ -60,8 +83,8 @@
|
||||
<!-- radio button to select between file upload or microphone -->
|
||||
<div id="input">
|
||||
Input:
|
||||
<input type="radio" id="file" name="input" value="file" checked="checked" onchange="changeInput('file')" /> File
|
||||
<input type="radio" id="mic" name="input" value="mic" onchange="changeInput('mic')" /> Microphone
|
||||
<input type="radio" id="file" name="input" value="file" checked="checked" onchange="changeInput('file')" /> <label for="file">File</label>
|
||||
<input type="radio" id="mic" name="input" value="mic" onchange="changeInput('mic')" /> <label for="mic">Microphone</label>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@ -159,6 +182,12 @@
|
||||
<option value="yi">Yiddish</option>
|
||||
</select>
|
||||
</td>
|
||||
<!-- Slider to select number of threads between 1 and 16 -->
|
||||
<td>
|
||||
Threads:
|
||||
<input type="range" id="threads" name="threads" min="1" max="16" value="8" onchange="changeThreads(this.value)" />
|
||||
<span id="threads-value">8</span>
|
||||
</td>
|
||||
<td>
|
||||
<button onclick="onProcess(false);">Transcribe</button>
|
||||
</td>
|
||||
@ -261,11 +290,13 @@
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
model_whisper = fname;
|
||||
//model_whisper = fname;
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
document.getElementById('model').innerHTML = 'Model fetched: ' + model_whisper;
|
||||
}
|
||||
|
||||
function loadFile(event, fname) {
|
||||
@ -284,27 +315,64 @@
|
||||
}
|
||||
reader.readAsArrayBuffer(file);
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base' ).style.display = 'none';
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
document.getElementById('fetch-whisper-tiny-en' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small' ).style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-q5_0' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'tiny': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'tiny': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
|
||||
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
|
||||
'small': 'https://whisper.ggerganov.com/ggml-model-whisper-small.bin',
|
||||
|
||||
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
'tiny-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny-q5_1.bin',
|
||||
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
'base-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base-q5_1.bin',
|
||||
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
|
||||
'small-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small-q5_1.bin',
|
||||
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
|
||||
'medium-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-medium-q5_0.bin',
|
||||
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'tiny': 75,
|
||||
'base.en': 142,
|
||||
'base': 142,
|
||||
'tiny.en': 75,
|
||||
'tiny': 75,
|
||||
'base.en': 142,
|
||||
'base': 142,
|
||||
'small.en': 466,
|
||||
'small': 466,
|
||||
|
||||
'tiny-en-q5_1': 31,
|
||||
'tiny-q5_1': 31,
|
||||
'base-en-q5_1': 57,
|
||||
'base-q5_1': 57,
|
||||
'small-en-q5_1': 182,
|
||||
'small-q5_1': 182,
|
||||
'medium-en-q5_0': 515,
|
||||
'medium-q5_0': 515,
|
||||
'large-q5_0': 1030,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
@ -313,12 +381,25 @@
|
||||
|
||||
model_whisper = model;
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base' ).style.display = 'none';
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model;
|
||||
document.getElementById('fetch-whisper-tiny-en' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small' ).style.display = 'none';
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-tiny-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-small-q5_1' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-medium-q5_0' ).style.display = 'none';
|
||||
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
|
||||
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading model: ' + model;
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
@ -327,12 +408,26 @@
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = '';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small' ); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('fetch-whisper-tiny-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-tiny-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-en-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-small-q5_1' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-medium-en-q5_0'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-medium-q5_0' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-large-q5_0' ); if (el) el.style.display = 'inline-block';
|
||||
|
||||
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
@ -342,7 +437,8 @@
|
||||
// audio file
|
||||
//
|
||||
|
||||
const kMaxAudio_s = 120;
|
||||
const kMaxAudio_s = 30*60;
|
||||
const kMaxRecording_s = 2*60;
|
||||
const kSampleRate = 16000;
|
||||
|
||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
||||
@ -411,7 +507,7 @@
|
||||
doRecording = false;
|
||||
}
|
||||
|
||||
// record up to kMaxAudio_s seconds of audio from the microphone
|
||||
// record up to kMaxRecording_s seconds of audio from the microphone
|
||||
// check if doRecording is false every 1000 ms and stop recording if so
|
||||
// update progress information
|
||||
function startRecording() {
|
||||
@ -467,9 +563,9 @@
|
||||
printTextarea('js: audio recorded, size: ' + audio.length);
|
||||
|
||||
// truncate to first 30 seconds
|
||||
if (audio.length > kMaxAudio_s*kSampleRate) {
|
||||
audio = audio.slice(0, kMaxAudio_s*kSampleRate);
|
||||
printTextarea('js: truncated audio to first ' + kMaxAudio_s + ' seconds');
|
||||
if (audio.length > kMaxRecording_s*kSampleRate) {
|
||||
audio = audio.slice(0, kMaxRecording_s*kSampleRate);
|
||||
printTextarea('js: truncated audio to first ' + kMaxRecording_s + ' seconds');
|
||||
}
|
||||
setAudio(audio);
|
||||
});
|
||||
@ -497,24 +593,31 @@
|
||||
});
|
||||
}
|
||||
|
||||
document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxAudio_s) + '%';
|
||||
document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxAudio_s).toFixed(0) + '%';
|
||||
document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxRecording_s) + '%';
|
||||
document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxRecording_s).toFixed(0) + '%';
|
||||
}, 1000);
|
||||
|
||||
printTextarea('js: recording ...');
|
||||
|
||||
setTimeout(function() {
|
||||
if (doRecording) {
|
||||
printTextarea('js: recording stopped after ' + kMaxAudio_s + ' seconds');
|
||||
printTextarea('js: recording stopped after ' + kMaxRecording_s + ' seconds');
|
||||
stopRecording();
|
||||
}
|
||||
}, kMaxAudio_s*1000);
|
||||
}, kMaxRecording_s*1000);
|
||||
}
|
||||
|
||||
//
|
||||
// transcribe
|
||||
//
|
||||
|
||||
var nthreads = 8;
|
||||
|
||||
function changeThreads(value) {
|
||||
nthreads = value;
|
||||
document.getElementById('threads-value').innerHTML = nthreads;
|
||||
}
|
||||
|
||||
function onProcess(translate) {
|
||||
if (!instance) {
|
||||
instance = Module.init('whisper.bin');
|
||||
@ -541,7 +644,7 @@
|
||||
printTextarea('');
|
||||
|
||||
setTimeout(function() {
|
||||
var ret = Module.full_default(instance, audio, document.getElementById('language').value, translate);
|
||||
var ret = Module.full_default(instance, audio, document.getElementById('language').value, nthreads, translate);
|
||||
console.log('js: full_default returned: ' + ret);
|
||||
if (ret) {
|
||||
printTextarea("js: whisper returned: " + ret);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user