Compare commits
262 Commits
parallel-s
...
gg/fix-ext
Author | SHA1 | Date | |
---|---|---|---|
f25edade2b | |||
74c260fe34 | |||
551529290d | |||
25a90ffa38 | |||
866b67ca93 | |||
d7e9f58f7f | |||
04839bae22 | |||
3cc6e04a52 | |||
b7ef178b9c | |||
47dfe9d4db | |||
1d3270cc8f | |||
a6fb6ab597 | |||
163e74b6c3 | |||
f273e66dc6 | |||
02b4c52c12 | |||
518199c09e | |||
8b17a2f776 | |||
b6d2827914 | |||
9711bae0b3 | |||
eec38f63bd | |||
ef5e6b746f | |||
77bf6b5f56 | |||
b562fff9d0 | |||
b5dec374f4 | |||
fa0dc6167c | |||
55bcd62a4b | |||
0ed762d691 | |||
1b5bb7792e | |||
9b735cea77 | |||
12c462d656 | |||
fc7b0e2c28 | |||
f850a067ed | |||
f75e1197f1 | |||
aa8a75e287 | |||
80e8a2ea39 | |||
19f8048139 | |||
0f80e5a80a | |||
b6559333ff | |||
434b8f3b96 | |||
7a74e929c8 | |||
361ecebe90 | |||
807cbc672e | |||
98ae5276b7 | |||
6adb969b09 | |||
8a7d6ff51a | |||
25f650a8e8 | |||
44e517f074 | |||
cb9de61659 | |||
a2ef80d66f | |||
baa190446a | |||
8f5220d81f | |||
8e391fcf3a | |||
593657054e | |||
ae5c4f7340 | |||
baa30bacdb | |||
3e6fad07aa | |||
e72e4158de | |||
bd41733db2 | |||
23c648e98d | |||
75ab2d06f5 | |||
adc099edee | |||
52cce82493 | |||
ef3c9ed9eb | |||
7fe3ed5e00 | |||
6061241292 | |||
0878ab7c15 | |||
c65edd5b64 | |||
3c8d14e9c5 | |||
c3977cb2ce | |||
6da1661bc2 | |||
cc56540661 | |||
94c1ae8668 | |||
55d54359e0 | |||
d33c2ad354 | |||
9afa7ff624 | |||
0649289f02 | |||
aaeaa43878 | |||
078b8e23bf | |||
74da3e1757 | |||
2d2c93a798 | |||
4bbb60efce | |||
1cf679dec4 | |||
41026c1e4b | |||
d6b9be21d7 | |||
c0329acde8 | |||
fb466b3417 | |||
1f50a7d29f | |||
1de21b913d | |||
4aea058e5a | |||
fd10234363 | |||
8fb5c6a409 | |||
2fe5fbfcc2 | |||
01637e1a4c | |||
1b349eb1f9 | |||
138eaebead | |||
61b9192f27 | |||
161b51d91a | |||
f904b31a7d | |||
f6614155e4 | |||
f5f159c320 | |||
6ebba525f1 | |||
2a5874441d | |||
d08445c9ad | |||
4a945696cb | |||
dabc964d83 | |||
654baf693d | |||
f001a3b7b6 | |||
c615f2c335 | |||
d839dd0242 | |||
435847891c | |||
182f290808 | |||
447dfc11fc | |||
9aa9f3b84e | |||
396ebd1e80 | |||
12490f4398 | |||
db078a9ba8 | |||
a13a7da5ad | |||
519f8e8684 | |||
40ae0962f4 | |||
1560288048 | |||
1ad6fafd91 | |||
70840aed5f | |||
b24d18feb9 | |||
3fa98f4395 | |||
d05b7ee90e | |||
6dcee35129 | |||
5cb345f5e9 | |||
fbcb52d3cd | |||
6b01e3fedd | |||
f7908f9bb8 | |||
00b7a4be02 | |||
04b0a768b8 | |||
87670425f2 | |||
32e71a1861 | |||
9c857cf280 | |||
97b12212dd | |||
9fa34d79ec | |||
a0a64a19dd | |||
bbc23611fa | |||
e9783a1fb4 | |||
9e0cc28792 | |||
73072a7c73 | |||
a8ba1262ff | |||
e66a9a7806 | |||
338442d773 | |||
10651bddf6 | |||
53d4d0b30d | |||
2865e4710b | |||
c46a74a19d | |||
46dc49a6a1 | |||
cc7f872131 | |||
bcc1658cd0 | |||
c46886f599 | |||
29f78392c1 | |||
022756a872 | |||
3b8c2dff57 | |||
0b9af32a8b | |||
11b1b63b14 | |||
0e26a6c92e | |||
66d8f0b7f1 | |||
ba5bcde874 | |||
ab0a8593c5 | |||
668ffc9b23 | |||
9962371f71 | |||
993acb5d41 | |||
a3d0aa73d1 | |||
14c57952f7 | |||
6c369d6788 | |||
4cdd9aad9b | |||
f38c057503 | |||
1e5544b39b | |||
d5673af79f | |||
a28dacec65 | |||
dbe29d4e33 | |||
fe3a67c546 | |||
b138ff2be3 | |||
cf6f1e4181 | |||
620a223814 | |||
f39f9690ec | |||
f9ca90256b | |||
2623640cd6 | |||
d87de61ae6 | |||
f5f485f899 | |||
e77b27c331 | |||
a5cc3dc8a2 | |||
37a709f655 | |||
3a5302108d | |||
d2ee117a0a | |||
db8ccdb850 | |||
d2419030b0 | |||
8986690c2a | |||
9286d3f584 | |||
940de9dbe9 | |||
88112c8afb | |||
375585c07c | |||
fd99ece8e3 | |||
8171e621fc | |||
ec03661b20 | |||
6335933a5b | |||
885b5563d0 | |||
9521ba6801 | |||
29511d33c7 | |||
7bc4d22337 | |||
afce6fa113 | |||
3163090d89 | |||
f0efd0202d | |||
3c28d1a571 | |||
e369243ebd | |||
a0ec3fac54 | |||
6559b538e5 | |||
73d5005880 | |||
6b094b6dfe | |||
641f2f4282 | |||
bfacd9f8ce | |||
f52e74d4dc | |||
23c21e92eb | |||
447d49530c | |||
9d6ebd877c | |||
0ba365f958 | |||
010c8ec3ab | |||
ffdb5c4735 | |||
a5881d619c | |||
34f70b3a56 | |||
8328d1900f | |||
d2bd5f0bdc | |||
34209a37a2 | |||
180e062eda | |||
5c7be85fdc | |||
146169ec38 | |||
9befab5ab9 | |||
9ac88f2b57 | |||
46f5b6cb08 | |||
eff3570f78 | |||
fa19bc4195 | |||
a01b2e0971 | |||
8159a9ab99 | |||
7516d9c16d | |||
46cc26d1b9 | |||
f784f9fa12 | |||
ca23f8ee6d | |||
e2f0eba2d4 | |||
d4353e48f7 | |||
bebf0da983 | |||
848e54f3ad | |||
7883d1cae4 | |||
ccc85b4ff8 | |||
c7606b47df | |||
d38af151a1 | |||
94267df08e | |||
8713c67133 | |||
57a60639bb | |||
bfbaa4dce5 | |||
1d79e78402 | |||
b6c5f49b78 | |||
d4231649e6 | |||
3e5c7feeff | |||
c23598e4ca | |||
54a08bde29 | |||
9f8bbd3fee | |||
3172006a24 | |||
684bc8bd70 | |||
b0502836b8 |
38
.devops/main-cuda.Dockerfile
Normal file
@ -0,0 +1,38 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=12.3.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
# Target the CUDA runtime image
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
WORKDIR /app
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
# Set nvcc architecture
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
# Enable cuBLAS
|
||||
ENV WHISPER_CUBLAS=1
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
# Ref: https://stackoverflow.com/a/53464012
|
||||
ENV CUDA_MAIN_VERSION=12.3
|
||||
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
||||
|
||||
COPY .. .
|
||||
RUN make
|
||||
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
ENTRYPOINT [ "bash", "-c" ]
|
19
.devops/main.Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
FROM ubuntu:22.04 AS build
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY .. .
|
||||
RUN make
|
||||
|
||||
FROM ubuntu:22.04 AS runtime
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
ENTRYPOINT [ "bash", "-c" ]
|
101
.github/workflows/build.yml
vendored
@ -25,6 +25,7 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev
|
||||
make
|
||||
@ -86,6 +87,7 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
@ -113,8 +115,9 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
apt install -y clang build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
@ -140,6 +143,7 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake
|
||||
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
@ -162,7 +166,7 @@ jobs:
|
||||
s2arc: x64
|
||||
jnaPath: win32-x86-64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
s2ver: 2.28.5
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -217,13 +221,16 @@ jobs:
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
|
||||
s2arc: x86
|
||||
clblast: OFF
|
||||
- arch: x64
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
|
||||
s2arc: x64
|
||||
clblast: ON
|
||||
clver: 1.6.1
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
s2ver: 2.28.5
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -248,6 +255,18 @@ jobs:
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Install OpenCL
|
||||
if: matrix.clblast == 'ON'
|
||||
run: vcpkg.exe --triplet=${{ matrix.arch }}-windows install opencl
|
||||
|
||||
- name: Fetch CLBlast and set CLBlast_DIR
|
||||
if: matrix.clblast == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO clblast.zip https://github.com/CNugteren/CLBlast/releases/download/${{ matrix.clver }}/CLBlast-${{ matrix.clver }}-windows-x64.zip
|
||||
7z x clblast.zip
|
||||
7z x CLBlast-${{ matrix.clver }}-windows-x64.7z
|
||||
echo "CLBlast_DIR=$env:GITHUB_WORKSPACE/CLBlast-${{ matrix.clver }}-windows-x64/lib/cmake/CLBlast" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
@ -255,6 +274,7 @@ jobs:
|
||||
-DWHISPER_OPENBLAS=${{ matrix.blas }}
|
||||
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
-DWHISPER_CLBLAST=${{ matrix.clblast }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
@ -269,11 +289,15 @@ jobs:
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy clblast.dll
|
||||
if: matrix.clblast == 'ON'
|
||||
run: copy "$env:CLBlast_DIR/../../clblast.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-cublas:
|
||||
@ -285,11 +309,12 @@ jobs:
|
||||
arch: [x64]
|
||||
cublas: [ON]
|
||||
sdl2: [ON]
|
||||
cuda-toolkit: [12.2.0, 11.8.0]
|
||||
include:
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
s2ver: 2.28.5
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -300,7 +325,9 @@ jobs:
|
||||
|
||||
- name: Install CUDA Toolkit
|
||||
id: cuda-toolkit
|
||||
uses: Jimver/cuda-toolkit@v0.2.10
|
||||
uses: Jimver/cuda-toolkit@v0.2.11
|
||||
with:
|
||||
cuda: '${{ matrix.cuda-toolkit }}'
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
@ -313,12 +340,20 @@ jobs:
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_CUBLAS=1
|
||||
-DWHISPER_CUBLAS=${{ matrix.cublas }}
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
- name: Build ${{ matrix.cuda-toolkit }}
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
cmake --build . --config ${{ matrix.build }}
|
||||
|
||||
- name: Copy CUDA DLLs
|
||||
run: >
|
||||
Copy-Item -PassThru
|
||||
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
|
||||
-Include cudart64_*,cublas64_*,cublasLt64_*
|
||||
-Destination build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
@ -328,7 +363,7 @@ jobs:
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-cublas-bin-${{ matrix.arch }}
|
||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
@ -381,6 +416,14 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
path: whisper
|
||||
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
repository: ggerganov/ggml
|
||||
path: ggml
|
||||
|
||||
- name: Install Java
|
||||
uses: actions/setup-java@v3
|
||||
@ -393,9 +436,41 @@ jobs:
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android
|
||||
cd whisper/examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
||||
- name: Build with external ggml
|
||||
run: |
|
||||
export PATH_TO_GGML=$PWD/ggml
|
||||
cd whisper/examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon -PGGML_HOME=$PATH_TO_GGML
|
||||
|
||||
android_java:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set up JDK 11
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: gradle
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
with:
|
||||
api-level: 30
|
||||
build-tools-version: 30.0.3
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android.java
|
||||
chmod +x ./gradlew
|
||||
./gradlew assembleRelease
|
||||
|
||||
java:
|
||||
needs: [ 'windows' ]
|
||||
runs-on: windows-latest
|
||||
|
57
.github/workflows/docker.yml
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
name: Publish Docker image
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Hub
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64,linux/arm64" }
|
||||
- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
||||
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker image (versioned)
|
||||
if: github.event_name == 'push'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
|
||||
- name: Build and push Docker image (tagged)
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
push: ${{ github.event_name == 'push' }}
|
||||
platforms: ${{ matrix.config.platforms }}
|
||||
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
|
||||
file: ${{ matrix.config.dockerfile }}
|
5
.gitignore
vendored
@ -31,6 +31,7 @@ build-sanitize-thread/
|
||||
/talk-llama
|
||||
/bench
|
||||
/quantize
|
||||
/server
|
||||
/lsp
|
||||
|
||||
arm_neon.h
|
||||
@ -54,3 +55,7 @@ bindings/java/.idea/
|
||||
.idea/
|
||||
|
||||
benchmark_results.csv
|
||||
cmake-build-debug/
|
||||
.cxx/
|
||||
.gradle/
|
||||
local.properties
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(whisper.cpp VERSION 1.4.3)
|
||||
project(whisper.cpp VERSION 1.5.4)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
@ -68,6 +68,7 @@ if (APPLE)
|
||||
option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF)
|
||||
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
||||
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
||||
option(WHISPER_METAL_EMBED_LIBRARY "whisper: embed Metal library" OFF)
|
||||
else()
|
||||
option(WHISPER_BLAS "whisper: use BLAS libraries" OFF)
|
||||
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
|
||||
@ -147,6 +148,30 @@ if (APPLE)
|
||||
|
||||
# copy ggml-metal.metal to bin directory
|
||||
configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
|
||||
|
||||
if (WHISPER_METAL_EMBED_LIBRARY)
|
||||
enable_language(ASM)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_EMBED_LIBRARY)
|
||||
|
||||
set(METALLIB_SOURCE "${CMAKE_SOURCE_DIR}/ggml-metal.metal")
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
|
||||
set(EMBED_METALLIB_ASSEMBLY "${CMAKE_BINARY_DIR}/autogenerated/ggml-embed-metallib.s")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo ".section __DATA,__ggml_metallib" > ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo ".globl _ggml_metallib_start" >> ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo "_ggml_metallib_start:" >> ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo ".incbin \\\"${METALLIB_SOURCE}\\\"" >> ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo ".globl _ggml_metallib_end" >> ${EMBED_METALLIB_ASSEMBLY}
|
||||
COMMAND echo "_ggml_metallib_end:" >> ${EMBED_METALLIB_ASSEMBLY}
|
||||
DEPENDS ${METALLIB_SOURCE}
|
||||
COMMENT "Generate assembly for embedded Metal library"
|
||||
)
|
||||
|
||||
set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${EMBED_METALLIB_ASSEMBLY})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML)
|
||||
@ -218,11 +243,17 @@ if (WHISPER_CUBLAS)
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
|
||||
if (WHISPER_STATIC)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
if (WIN32)
|
||||
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
||||
else ()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
endif()
|
||||
else()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
|
||||
else()
|
||||
message(FATAL_ERROR "cuBLAS not found")
|
||||
endif()
|
||||
@ -309,7 +340,8 @@ if (WHISPER_ALL_WARNINGS)
|
||||
endif()
|
||||
|
||||
if (NOT MSVC)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
|
||||
# TODO: temporary disabled until we figure out ggml-metal.m
|
||||
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
|
||||
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
|
||||
endif()
|
||||
|
||||
@ -338,8 +370,8 @@ else()
|
||||
endif()
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -s TOTAL_STACK=5242880")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -s TOTAL_STACK=5242880")
|
||||
else()
|
||||
if(NOT WHISPER_NO_AVX)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
|
||||
@ -498,6 +530,7 @@ else()
|
||||
endif()
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_link_libraries(${TARGET} PUBLIC
|
||||
${CMAKE_DL_LIBS}
|
||||
)
|
||||
@ -521,7 +554,13 @@ endif()
|
||||
|
||||
if (GGML_SOURCES_CUDA)
|
||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
# Only configure gmml CUDA architectures is not globally set
|
||||
if (NOT DEFINED GGML_CUDA_ARCHITECTURES)
|
||||
# Not overriden by user, so set defaults
|
||||
set(GGML_CUDA_ARCHITECTURES 52 61 70)
|
||||
endif()
|
||||
message(STATUS "GGML Configuring CUDA architectures ${GGML_CUDA_ARCHITECTURES}")
|
||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES ${GGML_CUDA_ARCHITECTURES})
|
||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||
endif()
|
||||
|
||||
@ -533,7 +572,7 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
|
57
Makefile
@ -1,4 +1,4 @@
|
||||
default: main bench quantize
|
||||
default: main bench quantize server
|
||||
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
@ -42,6 +42,12 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
ifdef MACOSX_DEPLOYMENT_TARGET
|
||||
CFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET)
|
||||
CXXFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET)
|
||||
LDFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET)
|
||||
endif
|
||||
|
||||
# clock_gettime came in POSIX.1b (1993)
|
||||
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
|
||||
# posix_memalign came in POSIX.1-2001 / SUSv3
|
||||
@ -99,6 +105,16 @@ ifeq ($(filter $(UNAME_S),Linux Darwin DragonFly FreeBSD NetBSD OpenBSD Haiku),$
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
|
||||
# detect Windows
|
||||
ifneq ($(findstring _NT,$(UNAME_S)),)
|
||||
_WIN32 := 1
|
||||
endif
|
||||
|
||||
# Windows Sockets 2 (Winsock) for network-capable apps
|
||||
ifeq ($(_WIN32),1)
|
||||
LWINSOCK2 := -lws2_32
|
||||
endif
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
@ -107,7 +123,7 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
||||
CPUINFO_CMD := sysctl machdep.cpu.features machdep.cpu.leaf7_features
|
||||
else ifeq ($(UNAME_S),Linux)
|
||||
CPUINFO_CMD := cat /proc/cpuinfo
|
||||
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
|
||||
else ifneq (,$(filter MINGW32_NT% MINGW64_NT% MSYS_NT%,$(UNAME_S)))
|
||||
CPUINFO_CMD := cat /proc/cpuinfo
|
||||
else ifneq (,$(filter DragonFly FreeBSD,$(UNAME_S)))
|
||||
CPUINFO_CMD := grep Features /var/run/dmesg.boot
|
||||
@ -199,14 +215,14 @@ endif
|
||||
|
||||
ifdef WHISPER_CUBLAS
|
||||
ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
|
||||
CUDA_ARCH_FLAG=native
|
||||
CUDA_ARCH_FLAG ?= native
|
||||
else
|
||||
CUDA_ARCH_FLAG=all
|
||||
CUDA_ARCH_FLAG ?= all
|
||||
endif
|
||||
|
||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
|
||||
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
|
||||
WHISPER_OBJ += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
||||
@ -329,6 +345,24 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml-metal.o
|
||||
|
||||
ifdef WHISPER_METAL_EMBED_LIBRARY
|
||||
CFLAGS += -DGGML_METAL_EMBED_LIBRARY
|
||||
|
||||
ggml-metal-embed.o: ggml-metal.metal
|
||||
@echo "Embedding Metal library"
|
||||
$(eval TEMP_ASSEMBLY=$(shell mktemp))
|
||||
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)
|
||||
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)
|
||||
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)
|
||||
@echo ".incbin \"$<\"" >> $(TEMP_ASSEMBLY)
|
||||
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)
|
||||
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)
|
||||
@$(AS) $(TEMP_ASSEMBLY) -o $@
|
||||
@rm -f ${TEMP_ASSEMBLY}
|
||||
|
||||
WHISPER_OBJ += ggml-metal-embed.o
|
||||
endif
|
||||
endif
|
||||
|
||||
libwhisper.a: $(WHISPER_OBJ)
|
||||
@ -338,7 +372,7 @@ libwhisper.so: $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
|
||||
rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -359,11 +393,14 @@ bench: examples/bench/bench.cpp $(WHISPER_OBJ)
|
||||
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
|
||||
server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o server $(LDFLAGS) $(LWINSOCK2)
|
||||
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
@ -418,9 +455,9 @@ samples:
|
||||
.PHONY: medium
|
||||
.PHONY: large-v1
|
||||
.PHONY: large-v2
|
||||
.PHONY: large
|
||||
.PHONY: large-v3
|
||||
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large: main
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3: main
|
||||
bash ./models/download-ggml-model.sh $@
|
||||
@echo ""
|
||||
@echo "==============================================="
|
||||
|
@ -2,41 +2,26 @@
|
||||
|
||||
import PackageDescription
|
||||
|
||||
#if arch(arm) || arch(arm64)
|
||||
let platforms: [SupportedPlatform]? = [
|
||||
.macOS(.v12),
|
||||
.iOS(.v14),
|
||||
.watchOS(.v4),
|
||||
.tvOS(.v14)
|
||||
]
|
||||
let exclude: [String] = []
|
||||
let resources: [Resource] = [
|
||||
.process("ggml-metal.metal")
|
||||
]
|
||||
let additionalSources: [String] = ["ggml-metal.m"]
|
||||
let additionalSettings: [CSetting] = [
|
||||
.unsafeFlags(["-fno-objc-arc"]),
|
||||
.define("GGML_USE_METAL")
|
||||
]
|
||||
#else
|
||||
let platforms: [SupportedPlatform]? = nil
|
||||
let exclude: [String] = ["ggml-metal.metal"]
|
||||
let resources: [Resource] = []
|
||||
let additionalSources: [String] = []
|
||||
let additionalSettings: [CSetting] = []
|
||||
#endif
|
||||
|
||||
let package = Package(
|
||||
name: "whisper",
|
||||
platforms: platforms,
|
||||
platforms: [
|
||||
.macOS(.v12),
|
||||
.iOS(.v14),
|
||||
.watchOS(.v4),
|
||||
.tvOS(.v14)
|
||||
],
|
||||
products: [
|
||||
.library(name: "whisper", targets: ["whisper"]),
|
||||
],
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/ggerganov/ggml.git", .branch("release"))
|
||||
],
|
||||
targets: [
|
||||
.target(
|
||||
name: "whisper",
|
||||
dependencies: ["ggml"],
|
||||
path: ".",
|
||||
exclude: exclude + [
|
||||
exclude: [
|
||||
"bindings",
|
||||
"cmake",
|
||||
"coreml",
|
||||
@ -51,23 +36,20 @@ let package = Package(
|
||||
"Makefile"
|
||||
],
|
||||
sources: [
|
||||
"ggml.c",
|
||||
"whisper.cpp",
|
||||
"ggml-alloc.c",
|
||||
"ggml-backend.c",
|
||||
"ggml-quants.c"
|
||||
] + additionalSources,
|
||||
resources: resources,
|
||||
],
|
||||
publicHeadersPath: "spm-headers",
|
||||
cSettings: [
|
||||
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
|
||||
.define("GGML_USE_ACCELERATE")
|
||||
.define("GGML_USE_ACCELERATE"),
|
||||
.unsafeFlags(["-fno-objc-arc"]),
|
||||
.define("GGML_USE_METAL")
|
||||
// NOTE: NEW_LAPACK will required iOS version 16.4+
|
||||
// We should consider add this in the future when we drop support for iOS 14
|
||||
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
|
||||
// .define("ACCELERATE_NEW_LAPACK"),
|
||||
// .define("ACCELERATE_LAPACK_ILP64")
|
||||
] + additionalSettings,
|
||||
],
|
||||
linkerSettings: [
|
||||
.linkedFramework("Accelerate")
|
||||
]
|
||||
|
199
README.md
@ -6,7 +6,7 @@
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Beta: [v1.4.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.3) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Stable: [v1.5.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
@ -16,12 +16,10 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Support for CPU-only inference
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
@ -35,11 +33,10 @@ Supported platforms:
|
||||
- [x] [WebAssembly](examples/whisper.wasm)
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
- [x] [docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
|
||||
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
|
||||
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
|
||||
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
||||
The rest of the code is part of the [`ggml`](https://github.com/ggerganov/ggml) machine learning library.
|
||||
|
||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||
@ -64,22 +61,22 @@ Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
|
||||
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
|
||||
- Various other examples are available in the [examples](examples) folder
|
||||
|
||||
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
|
||||
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
|
||||
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
|
||||
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
|
||||
|
||||
## Quick start
|
||||
|
||||
First clone the repository.
|
||||
First clone the repository:
|
||||
|
||||
Then, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
```
|
||||
|
||||
Then, download one of the Whisper [models](models/README.md) converted in [`ggml` format](#ggml-format). For example:
|
||||
|
||||
```bash
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
```
|
||||
|
||||
If you wish to convert the Whisper models to ggml format yourself, instructions are in [models/README.md](models/README.md).
|
||||
|
||||
Now build the [main](examples/main) example and transcribe an audio file like this:
|
||||
|
||||
```bash
|
||||
@ -94,7 +91,7 @@ make
|
||||
|
||||
For a quick demo, simply run `make base.en`:
|
||||
|
||||
```java
|
||||
```text
|
||||
$ make base.en
|
||||
|
||||
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
|
||||
@ -114,8 +111,8 @@ options:
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [5 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
@ -132,6 +129,7 @@ options:
|
||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
@ -143,7 +141,8 @@ options:
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-ls, --log-score [false ] log best decoder scores of token
|
||||
-ls, --log-score [false ] log best decoder scores of tokens
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -208,7 +207,7 @@ For detailed usage instructions, run: `./main -h`
|
||||
Note that the [main](examples/main) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
|
||||
For example, you can use `ffmpeg` like this:
|
||||
|
||||
```java
|
||||
```bash
|
||||
ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
|
||||
```
|
||||
|
||||
@ -235,18 +234,18 @@ make medium.en
|
||||
make medium
|
||||
make large-v1
|
||||
make large-v2
|
||||
make large
|
||||
make large-v3
|
||||
```
|
||||
|
||||
## Memory usage
|
||||
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~3.3 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
|
||||
| Model | Disk | Mem |
|
||||
| ------ | ------- | ------- |
|
||||
| tiny | 75 MiB | ~273 MB |
|
||||
| base | 142 MiB | ~388 MB |
|
||||
| small | 466 MiB | ~852 MB |
|
||||
| medium | 1.5 GiB | ~2.1 GB |
|
||||
| large | 2.9 GiB | ~3.9 GB |
|
||||
|
||||
## Quantization
|
||||
|
||||
@ -279,7 +278,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
|
||||
|
||||
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
|
||||
- Python 3.10 is recommended.
|
||||
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
|
||||
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
|
||||
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
|
||||
- To activate the environment, use: `conda activate py310-whisper`
|
||||
|
||||
@ -305,8 +304,8 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
|
||||
|
||||
- Run the examples as usual. For example:
|
||||
|
||||
```bash
|
||||
./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
```text
|
||||
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
|
||||
...
|
||||
|
||||
@ -334,7 +333,8 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
- First, setup python virtual env. and install python dependencies. Python 3.10 is recommended.
|
||||
|
||||
Windows:
|
||||
```
|
||||
|
||||
```powershell
|
||||
cd models
|
||||
python -m venv openvino_conv_env
|
||||
openvino_conv_env\Scripts\activate
|
||||
@ -343,7 +343,8 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
```
|
||||
|
||||
Linux and macOS:
|
||||
```
|
||||
|
||||
```bash
|
||||
cd models
|
||||
python3 -m venv openvino_conv_env
|
||||
source openvino_conv_env/bin/activate
|
||||
@ -357,7 +358,7 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
python convert-whisper-to-openvino.py --model base.en
|
||||
```
|
||||
|
||||
This will produce ggml-base.en-encoder-openvino.xml/.bin IR model files. It's recommended to relocate these to the same folder as ggml models, as that
|
||||
This will produce ggml-base.en-encoder-openvino.xml/.bin IR model files. It's recommended to relocate these to the same folder as `ggml` models, as that
|
||||
is the default location that the OpenVINO extension will search at runtime.
|
||||
|
||||
- Build `whisper.cpp` with OpenVINO support:
|
||||
@ -367,24 +368,28 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
After downloading & extracting package onto your development system, set up required environment by sourcing setupvars script. For example:
|
||||
|
||||
Linux:
|
||||
|
||||
```bash
|
||||
source /path/to/l_openvino_toolkit_ubuntu22_2023.0.0.10926.b4452d56304_x86_64/setupvars.sh
|
||||
```
|
||||
|
||||
Windows (cmd):
|
||||
```
|
||||
|
||||
```powershell
|
||||
C:\Path\To\w_openvino_toolkit_windows_2023.0.0.10926.b4452d56304_x86_64\setupvars.bat
|
||||
```
|
||||
|
||||
And then build the project using cmake:
|
||||
|
||||
```bash
|
||||
cmake -B build -DWHISPER_OPENVINO=1
|
||||
cmake --build build -j --config Release
|
||||
```
|
||||
|
||||
- Run the examples as usual. For example:
|
||||
```bash
|
||||
./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
|
||||
```text
|
||||
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
|
||||
...
|
||||
|
||||
@ -400,12 +405,12 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
|
||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||
cached for the next run.
|
||||
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
## NVIDIA GPU support
|
||||
|
||||
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
|
||||
With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
@ -435,7 +440,6 @@ cmake -B build -DWHISPER_CLBLAST=ON
|
||||
cmake --build build -j --config Release
|
||||
```
|
||||
|
||||
|
||||
Run all the examples as usual.
|
||||
|
||||
## BLAS CPU support via OpenBLAS
|
||||
@ -450,6 +454,38 @@ make clean
|
||||
WHISPER_OPENBLAS=1 make -j
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker must be installed and running on your system.
|
||||
- Create a folder to store big models & intermediate files (ex. /whisper/models)
|
||||
|
||||
### Images
|
||||
|
||||
We have two Docker images available for this project:
|
||||
|
||||
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
|
||||
### Usage
|
||||
|
||||
```shell
|
||||
# download model and persist it in a local folder
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
whisper.cpp:main "./models/download-ggml-model.sh base /models"
|
||||
# transcribe an audio file
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
-v path/to/audios:/audios \
|
||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
|
||||
# transcribe an audio file in samples folder
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
@ -462,7 +498,7 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
|
||||
<details>
|
||||
<summary>Expand to see the result</summary>
|
||||
|
||||
```java
|
||||
```text
|
||||
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
|
||||
|
||||
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
|
||||
@ -534,6 +570,7 @@ whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per
|
||||
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
|
||||
whisper_print_timings: total time = 32733.52 ms
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Real-time audio input example
|
||||
@ -542,7 +579,7 @@ This is a naive example of performing real-time inference on audio from your mic
|
||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
```bash
|
||||
make stream
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
@ -554,7 +591,7 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
|
||||
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
|
||||
to highlight words with high or low confidence:
|
||||
|
||||
```java
|
||||
```bash
|
||||
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
|
||||
```
|
||||
|
||||
@ -564,8 +601,8 @@ to highlight words with high or low confidence:
|
||||
|
||||
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
|
||||
```text
|
||||
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
@ -588,8 +625,8 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
|
||||
|
||||
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
|
||||
|
||||
```java
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
|
||||
```text
|
||||
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
|
||||
|
||||
whisper_model_load: loading model from './models/ggml-base.en.bin'
|
||||
...
|
||||
@ -659,7 +696,7 @@ This requires to have `ffmpeg` installed.
|
||||
|
||||
Here are a few *"typical"* examples:
|
||||
|
||||
```java
|
||||
```bash
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
|
||||
source ./samples/jfk.wav.wts
|
||||
ffplay ./samples/jfk.wav.mp4
|
||||
@ -669,7 +706,7 @@ https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b
|
||||
|
||||
---
|
||||
|
||||
```java
|
||||
```bash
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
|
||||
source ./samples/mm0.wav.wts
|
||||
ffplay ./samples/mm0.wav.mp4
|
||||
@ -679,7 +716,7 @@ https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-9
|
||||
|
||||
---
|
||||
|
||||
```java
|
||||
```bash
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
|
||||
source ./samples/gb0.wav.wts
|
||||
ffplay ./samples/gb0.wav.mp4
|
||||
@ -693,7 +730,7 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
|
||||
|
||||
Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
|
||||
|
||||
```java
|
||||
```bash
|
||||
./extra/bench-wts.sh samples/jfk.wav
|
||||
ffplay ./samples/jfk.wav.all.mp4
|
||||
```
|
||||
@ -722,8 +759,7 @@ It is written in python with the intention of being easy to modify and extend fo
|
||||
|
||||
It outputs a csv file with the results of the benchmarking.
|
||||
|
||||
|
||||
## ggml format
|
||||
## `ggml` format
|
||||
|
||||
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
|
||||
|
||||
@ -738,49 +774,50 @@ or manually from here:
|
||||
- https://huggingface.co/ggerganov/whisper.cpp
|
||||
- https://ggml.ggerganov.com
|
||||
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
|
||||
in [models](models).
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).
|
||||
|
||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Java:
|
||||
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [x] Java:
|
||||
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
|
||||
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
||||
- [x] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
|
||||
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
||||
- [x] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
||||
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
||||
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
||||
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
||||
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
||||
- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
||||
- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
||||
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
||||
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
||||
|
||||
## Examples
|
||||
|
||||
There are various examples of using the library for different projects in the [examples](examples) folder.
|
||||
Some of the examples are even ported to run in the browser using WebAssembly. Check them out!
|
||||
|
||||
| Example | Web | Description |
|
||||
| --- | --- | --- |
|
||||
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
|
||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||
| Example | Web | Description |
|
||||
| --------------------------------------------------- | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
|
||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||
| [server](examples/server) | | HTTP transcription server with OAI-like API |
|
||||
|
||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
||||
|
||||
|
@ -1,9 +1,26 @@
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
|
||||
ifndef UNAME_P
|
||||
UNAME_P := $(shell uname -p)
|
||||
endif
|
||||
|
||||
ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
|
||||
endif
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@ -11,8 +28,13 @@ whisper: mkdir
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
endif
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
@ -21,7 +43,11 @@ model-small: mkdir examples/go-model-download
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
endif
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
|
@ -24,7 +24,7 @@ const (
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large"}
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3"}
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -123,6 +123,11 @@ func (p *Params) SetAudioCtx(n int) {
|
||||
p.audio_ctx = C.int(n)
|
||||
}
|
||||
|
||||
// Set initial prompt
|
||||
func (p *Params) SetInitialPrompt(prompt string) {
|
||||
p.initial_prompt = C.CString(prompt)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
@ -147,6 +152,7 @@ func (p *Params) String() string {
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
|
||||
str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt))
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
|
@ -130,6 +130,11 @@ func (context *context) SetAudioCtx(n uint) {
|
||||
context.params.SetAudioCtx(int(n))
|
||||
}
|
||||
|
||||
// Set initial prompt
|
||||
func (context *context) SetInitialPrompt(prompt string) {
|
||||
context.params.SetInitialPrompt(prompt)
|
||||
}
|
||||
|
||||
// ResetTimings resets the mode timings. Should be called before processing
|
||||
func (context *context) ResetTimings() {
|
||||
context.model.ctx.Whisper_reset_timings()
|
||||
|
@ -38,17 +38,18 @@ type Context interface {
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetSplitOnWord(bool) // Set split on word flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
SetAudioCtx(uint) // Set audio encoder context
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetSplitOnWord(bool) // Set split on word flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
SetAudioCtx(uint) // Set audio encoder context
|
||||
SetInitialPrompt(prompt string) // Set initial prompt
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
|
@ -9,6 +9,7 @@ archivesBaseName = 'whispercpp'
|
||||
group = 'io.github.ggerganov'
|
||||
version = '1.4.0'
|
||||
|
||||
|
||||
sourceCompatibility = 1.8
|
||||
targetCompatibility = 1.8
|
||||
|
||||
|
@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -9,6 +10,8 @@ import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
|
||||
@ -160,6 +163,28 @@ public class WhisperCpp implements AutoCloseable {
|
||||
|
||||
return str.toString().trim();
|
||||
}
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
List<WhisperSegment> segments= new ArrayList<>(nSegments);
|
||||
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
long t1 = lib.whisper_full_get_segment_t1(ctx, i);
|
||||
segments.add(new WhisperSegment(t0,t1,text));
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
// public int getTextSegmentCount(Pointer ctx) {
|
||||
// return lib.whisper_full_n_segments(ctx);
|
||||
|
@ -0,0 +1,47 @@
|
||||
package io.github.ggerganov.whispercpp.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "[" + start + " --> " + end + "]:" + sentence;
|
||||
}
|
||||
}
|
@ -58,6 +58,9 @@ public class WhisperFullParams extends Structure {
|
||||
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Generate timestamps or not? */
|
||||
public CBool no_timestamps;
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public CBool single_segment;
|
||||
|
||||
@ -304,10 +307,16 @@ public class WhisperFullParams extends Structure {
|
||||
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
/** Grammar stuff */
|
||||
public Pointer grammar_rules;
|
||||
public long n_grammar_rules;
|
||||
public long i_start_rule;
|
||||
public float grammar_penalty;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment",
|
||||
"no_context", "single_segment", "no_timestamps",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
@ -316,6 +325,7 @@ public class WhisperFullParams extends Structure {
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||
"logits_filter_callback", "logits_filter_callback_user_data",
|
||||
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.CBool;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -11,6 +12,7 @@ import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.util.List;
|
||||
|
||||
class WhisperCppTest {
|
||||
private static WhisperCpp whisper = new WhisperCpp();
|
||||
@ -20,11 +22,12 @@ class WhisperCppTest {
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
//String modelName = "../../models/ggml-tiny.bin";
|
||||
String modelName = "../../models/ggml-tiny.en.bin";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
@ -42,7 +45,7 @@ class WhisperCppTest {
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertFalse(params.translate);
|
||||
assertEquals(0.01f, params.thold_pt);
|
||||
assertEquals(2, params.beam_search.beam_size);
|
||||
assertEquals(5, params.beam_search.beam_size);
|
||||
assertEquals(-1.0f, params.beam_search.patience);
|
||||
}
|
||||
|
||||
@ -55,7 +58,7 @@ class WhisperCppTest {
|
||||
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
||||
assertNotEquals(0, params.n_threads);
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertEquals(2, params.greedy.best_of);
|
||||
assertEquals(5, params.greedy.best_of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -72,11 +75,11 @@ class WhisperCppTest {
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
|
||||
try {
|
||||
@ -99,4 +102,43 @@ class WhisperCppTest {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFullTranscribeWithTime() throws Exception {
|
||||
if (!modelInitialised) {
|
||||
System.out.println("Model not initialised, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
// Given
|
||||
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
|
||||
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
|
||||
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
|
||||
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
|
||||
floats[j] = intSample / 32767.0f;
|
||||
}
|
||||
|
||||
List<WhisperSegment> segments = whisper.fullTranscribeWithTime(params, floats);
|
||||
assertTrue(segments.size() > 0, "The size of segments should be greater than 0");
|
||||
for (WhisperSegment segment : segments) {
|
||||
System.out.println(segment);
|
||||
}
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ make publish-npm
|
||||
|
||||
## Sample run
|
||||
|
||||
```java
|
||||
```text
|
||||
$ node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
|
||||
|
||||
whisper_model_load: loading model from 'whisper.bin'
|
||||
@ -63,7 +63,7 @@ whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 1 | BLAS = 0 |
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 1 | BLAS = 0 |
|
||||
|
||||
operator(): processing 176000 samples, 11.0 sec, 8 threads, 1 processors, lang = en, task = transcribe ...
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.4.3",
|
||||
"version": "1.5.4",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
@ -70,7 +70,7 @@ extern "C" {
|
||||
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||
|
||||
// compute graph without a plan
|
||||
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
|
||||
// check if the backend supports an operation
|
||||
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
||||
|
@ -156,8 +156,8 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
|
||||
backend->iface.graph_plan_compute(backend, plan);
|
||||
}
|
||||
|
||||
void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
backend->iface.graph_compute(backend, cgraph);
|
||||
bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
return backend->iface.graph_compute(backend, cgraph);
|
||||
}
|
||||
|
||||
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
|
@ -52,7 +52,7 @@ extern "C" {
|
||||
|
||||
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
|
||||
|
||||
// tensor copy between different backends
|
||||
|
@ -24,9 +24,9 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
|
||||
|
||||
// select which device to run the Core ML model on
|
||||
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
|
||||
config.computeUnits = MLComputeUnitsCPUAndGPU;
|
||||
// config.computeUnits = MLComputeUnitsCPUAndGPU;
|
||||
//config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
|
||||
//config.computeUnits = MLComputeUnitsAll;
|
||||
config.computeUnits = MLComputeUnitsAll;
|
||||
|
||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);
|
||||
|
||||
|
@ -14,6 +14,10 @@ if (WHISPER_SDL2)
|
||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||
endif()
|
||||
|
||||
if (WHISPER_CLBLAST)
|
||||
find_package(CLBlast REQUIRED)
|
||||
endif()
|
||||
|
||||
# common
|
||||
|
||||
set(TARGET common)
|
||||
@ -23,6 +27,7 @@ add_library(${TARGET} STATIC
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
grammar-parser.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
@ -64,6 +69,7 @@ elseif(CMAKE_JS_VERSION)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(server)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(quantize)
|
||||
@ -71,3 +77,5 @@ else()
|
||||
add_subdirectory(talk-llama)
|
||||
add_subdirectory(lsp)
|
||||
endif()
|
||||
|
||||
add_subdirectory(wchess)
|
||||
|
@ -154,7 +154,7 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
|
@ -58,7 +58,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
int whisper_bench_full(const whisper_params & params) {
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@ -21,6 +22,11 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -30,8 +36,12 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@ -45,6 +55,8 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -109,16 +124,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
t_ms = 0;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
@ -126,19 +155,41 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
||||
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
|
||||
} else {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
if(token.plog > 0.0f) exit(0);
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
const float p = 100.0f * std::exp(logprob_min0);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= txt.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
@ -613,7 +693,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
if (!params.grammar.empty()) {
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
@ -9,6 +9,11 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
|
||||
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
|
||||
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
|
||||
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
|
||||
{"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
|
||||
{"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
|
||||
{"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
|
||||
{"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
|
||||
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
|
||||
};
|
||||
|
||||
void ggml_print_ftypes(FILE * fp) {
|
||||
@ -48,15 +53,18 @@ bool ggml_common_quantize_0(
|
||||
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
|
||||
case GGML_FTYPE_UNKNOWN:
|
||||
case GGML_FTYPE_ALL_F32:
|
||||
case GGML_FTYPE_MOSTLY_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q2_K:
|
||||
case GGML_FTYPE_MOSTLY_Q3_K:
|
||||
case GGML_FTYPE_MOSTLY_Q4_K:
|
||||
case GGML_FTYPE_MOSTLY_Q5_K:
|
||||
case GGML_FTYPE_MOSTLY_Q6_K:
|
||||
case GGML_FTYPE_MOSTLY_IQ2_XXS:
|
||||
case GGML_FTYPE_MOSTLY_IQ2_XS:
|
||||
case GGML_FTYPE_MOSTLY_IQ3_XXS:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
|
||||
return false;
|
||||
@ -167,24 +175,17 @@ bool ggml_common_quantize_0(
|
||||
|
||||
switch ((ggml_type) ttype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], hist_cur.data(), nullptr);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
@ -192,12 +193,10 @@ bool ggml_common_quantize_0(
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
|
||||
|
@ -139,10 +139,13 @@ void audio_async::callback(uint8_t * stream, int len) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
if (n_samples > m_audio.size()) {
|
||||
n_samples = m_audio.size();
|
||||
|
||||
stream += (len - (n_samples * sizeof(float)));
|
||||
}
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
@ -153,7 +156,7 @@ void audio_async::callback(uint8_t * stream, int len) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
|
@ -41,7 +41,6 @@ private:
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
@ -615,6 +615,21 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
|
||||
|
||||
}
|
||||
|
||||
bool is_wav_buffer(const std::string buf) {
|
||||
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
|
||||
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
|
||||
if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4);
|
||||
if (chunk_size + 8 != buf.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
@ -639,6 +654,12 @@ bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (is_wav_buffer(fname)) {
|
||||
if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
|
||||
return false;
|
||||
|
@ -135,7 +135,11 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
|
||||
// Audio utils
|
||||
//
|
||||
|
||||
// Check if a buffer is a WAV audio file
|
||||
bool is_wav_buffer(const std::string buf);
|
||||
|
||||
// Read WAV audio file and store the PCM data into pcmf32
|
||||
// fname can be a buffer of WAV data instead of a filename
|
||||
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
|
||||
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
|
||||
bool read_wav(
|
||||
|
423
examples/grammar-parser.cpp
Normal file
@ -0,0 +1,423 @@
|
||||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from whisper.cpp
|
||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<whisper_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
||||
? WHISPER_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<whisper_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<whisper_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_char_element(whisper_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_CHAR: return true;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
whisper_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (auto kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
29
examples/grammar-parser.h
Normal file
@ -0,0 +1,29 @@
|
||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by whisper.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
@ -22,6 +22,7 @@ var printTextarea = (function() {
|
||||
async function clearCache() {
|
||||
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
|
||||
indexedDB.deleteDatabase(dbName);
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,7 @@ if [ -n "$3" ]; then
|
||||
fi
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
@ -435,7 +435,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// whisper init
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
// init audio
|
||||
|
@ -17,28 +17,37 @@ options:
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-bs N, --beam-size N [5 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-olrc, --output-lrc [false ] output result in a lrc file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
-dl, --detect-language [false ] exit after automatically detecting language
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-ls, --log-score [false ] log best decoder scores of tokens
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
```
|
||||
|
@ -62,8 +62,9 @@ struct whisper_params {
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
@ -85,6 +86,7 @@ struct whisper_params {
|
||||
bool output_jsn = false;
|
||||
bool output_jsn_full = false;
|
||||
bool output_lrc = false;
|
||||
bool no_prints = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
@ -135,6 +137,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-context") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
@ -155,6 +158,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
@ -165,8 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -193,6 +197,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
@ -212,6 +217,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
||||
fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
|
||||
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
@ -852,6 +858,9 @@ bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -878,9 +887,13 @@ int main(int argc, char ** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (params.no_prints) {
|
||||
whisper_log_set(cb_log_disable, NULL);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
@ -905,29 +918,28 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
|
||||
if (!params.no_prints) {
|
||||
// print system information
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
// print some info about the processing
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
@ -958,6 +970,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.debug_mode = params.debug_mode;
|
||||
@ -973,6 +986,8 @@ int main(int argc, char ** argv) {
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
// this callback is called on each new segment
|
||||
|
7
examples/python/test_whisper_processor.py
Normal file
@ -0,0 +1,7 @@
|
||||
import whisper_processor
|
||||
|
||||
try:
|
||||
result = whisper_processor.process_audio("./audio/wake_word_detected16k.wav", "base.en")
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
54
examples/python/whisper_processor.py
Normal file
@ -0,0 +1,54 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
def process_audio(wav_file, model_name="base.en"):
|
||||
"""
|
||||
Processes an audio file using a specified model and returns the processed string.
|
||||
|
||||
:param wav_file: Path to the WAV file
|
||||
:param model_name: Name of the model to use
|
||||
:return: Processed string output from the audio processing
|
||||
:raises: Exception if an error occurs during processing
|
||||
"""
|
||||
|
||||
model = f"./models/ggml-{model_name}.bin"
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.exists(model):
|
||||
raise FileNotFoundError(f"Model file not found: {model} \n\nDownload a model with this command:\n\n> bash ./models/download-ggml-model.sh {model_name}\n\n")
|
||||
|
||||
if not os.path.exists(wav_file):
|
||||
raise FileNotFoundError(f"WAV file not found: {wav_file}")
|
||||
|
||||
full_command = f"./main -m {model} -f {wav_file} -np -nt"
|
||||
|
||||
# Execute the command
|
||||
process = subprocess.Popen(full_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
# Get the output and error (if any)
|
||||
output, error = process.communicate()
|
||||
|
||||
if error:
|
||||
raise Exception(f"Error processing audio: {error.decode('utf-8')}")
|
||||
|
||||
# Process and return the output string
|
||||
decoded_str = output.decode('utf-8').strip()
|
||||
processed_str = decoded_str.replace('[BLANK_AUDIO]', '').strip()
|
||||
|
||||
return processed_str
|
||||
|
||||
def main():
|
||||
if len(sys.argv) >= 2:
|
||||
wav_file = sys.argv[1]
|
||||
model_name = sys.argv[2] if len(sys.argv) == 3 else "base.en"
|
||||
try:
|
||||
result = process_audio(wav_file, model_name)
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
else:
|
||||
print("Usage: python whisper_processor.py <wav_file> [<model_name>]")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
10
examples/server/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
||||
set(TARGET server)
|
||||
add_executable(${TARGET} server.cpp httplib.h json.hpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32)
|
||||
endif()
|
69
examples/server/README.md
Normal file
@ -0,0 +1,69 @@
|
||||
# whisper.cpp http server
|
||||
|
||||
Simple http server. WAV Files are passed to the inference model via http requests.
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
./server -h
|
||||
|
||||
usage: ./bin/server [options]
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pr, --print-realtime [false ] print output in realtime
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
-dl, --detect-language [false ] exit after automatically detecting language
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
||||
--port PORT, [8080 ] Port number for the server
|
||||
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
|
||||
|
||||
## request examples
|
||||
|
||||
**/inference**
|
||||
```
|
||||
curl 127.0.0.1:8080/inference \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F file="@<file-path>" \
|
||||
-F temperature="0.0" \
|
||||
-F temperature_inc="0.2" \
|
||||
-F response_format="json"
|
||||
```
|
||||
|
||||
**/load**
|
||||
```
|
||||
curl 127.0.0.1:8080/load \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F model="<path-to-model-file>"
|
||||
```
|
9262
examples/server/httplib.h
Normal file
24596
examples/server/json.hpp
Normal file
1022
examples/server/server.cpp
Normal file
@ -4,7 +4,7 @@ This is a naive example of performing real-time inference on audio from your mic
|
||||
The `stream` tool samples the audio every half a second and runs the transcription continously.
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
```bash
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
@ -14,7 +14,7 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
|
||||
|
||||
Setting the `--step` argument to `0` enables the sliding window mode:
|
||||
|
||||
```java
|
||||
```bash
|
||||
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
|
||||
```
|
||||
|
||||
@ -39,8 +39,8 @@ brew install sdl2
|
||||
make stream
|
||||
```
|
||||
|
||||
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
|
||||
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
|
||||
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
|
||||
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
|
||||
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
|
||||
|
||||
```bash
|
||||
|
@ -166,7 +166,7 @@ int main(int argc, char ** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
@ -1,25 +1,18 @@
|
||||
if (WHISPER_SDL2)
|
||||
# talk-llama
|
||||
set(TARGET talk-llama)
|
||||
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET}
|
||||
talk-llama.cpp
|
||||
llama.cpp
|
||||
../common.cpp
|
||||
../common-sdl.cpp
|
||||
../../ggml.c
|
||||
../../ggml-alloc.c
|
||||
../../ggml-backend.c
|
||||
../../ggml-quants.c
|
||||
../../whisper.cpp)
|
||||
if (WHISPER_CLBLAST)
|
||||
set(CLBLAST_LIBNAME clblast)
|
||||
endif ()
|
||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CLBLAST_LIBNAME} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
if(WIN32)
|
||||
# It requires Windows 8.1 or later for PrefetchVirtualMemory
|
||||
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
|
||||
endif()
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
endif ()
|
||||
|
@ -2,12 +2,8 @@
|
||||
#define LLAMA_H
|
||||
|
||||
#include "ggml.h"
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||
#else
|
||||
#define LLAMA_MAX_DEVICES 1
|
||||
#endif // GGML_USE_CUBLAS
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
@ -39,15 +35,11 @@
|
||||
|
||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 2
|
||||
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||
#define LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||
#endif
|
||||
#define LLAMA_SESSION_VERSION 4
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -69,6 +61,7 @@ extern "C" {
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
||||
};
|
||||
|
||||
enum llama_token_type {
|
||||
@ -102,6 +95,11 @@ extern "C" {
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
@ -114,6 +112,12 @@ extern "C" {
|
||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||
};
|
||||
|
||||
enum llama_split_mode {
|
||||
LLAMA_SPLIT_NONE = 0, // single GPU
|
||||
LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
|
||||
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
|
||||
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
@ -126,7 +130,7 @@ extern "C" {
|
||||
bool sorted;
|
||||
} llama_token_data_array;
|
||||
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
typedef bool (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
// Input data for llama_decode
|
||||
// A llama_batch object can contain input about one or many sequences
|
||||
@ -158,16 +162,46 @@ extern "C" {
|
||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
} llama_batch;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
LLAMA_KV_OVERRIDE_INT,
|
||||
LLAMA_KV_OVERRIDE_FLOAT,
|
||||
LLAMA_KV_OVERRIDE_BOOL,
|
||||
};
|
||||
|
||||
struct llama_model_kv_override {
|
||||
char key[128];
|
||||
enum llama_model_kv_override_type tag;
|
||||
union {
|
||||
int64_t int_value;
|
||||
double float_value;
|
||||
bool bool_value;
|
||||
};
|
||||
};
|
||||
|
||||
struct llama_model_params {
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
// main_gpu interpretation depends on split_mode:
|
||||
// LLAMA_SPLIT_NONE: the GPU that is used for the entire model
|
||||
// LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
|
||||
// LLAMA_SPLIT_LAYER: ignored
|
||||
int32_t main_gpu;
|
||||
|
||||
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
|
||||
const float * tensor_split;
|
||||
|
||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||
// If the provided progress_callback returns true, model loading continues.
|
||||
// If it returns false, model loading is immediately aborted.
|
||||
llama_progress_callback progress_callback;
|
||||
|
||||
// context pointer passed to the progress callback
|
||||
void * progress_callback_user_data;
|
||||
|
||||
// override key-value pairs of the model meta data
|
||||
const struct llama_model_kv_override * kv_overrides;
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
@ -180,32 +214,39 @@ extern "C" {
|
||||
uint32_t n_batch; // prompt processing maximum batch size
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
|
||||
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
|
||||
float yarn_attn_factor; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
||||
enum ggml_type type_k; // data type for K cache
|
||||
enum ggml_type type_v; // data type for V cache
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
bool f16_kv; // use fp16 for KV cache, fp32 otherwise
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool embedding; // embedding mode only
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embedding; // embedding mode only
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // disable k-quant mixtures and quantize all tensors to the same type
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
} llama_model_quantize_params;
|
||||
|
||||
// grammar types
|
||||
@ -284,25 +325,48 @@ extern "C" {
|
||||
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API int llama_max_devices (void);
|
||||
LLAMA_API bool llama_mmap_supported (void);
|
||||
LLAMA_API bool llama_mlock_supported(void);
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
|
||||
LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
|
||||
LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
// Functions to access the model's GGUF metadata scalar values
|
||||
// - The functions return the length of the string on success, or -1 on failure
|
||||
// - The output string is always null-terminated and cleared on failure
|
||||
// - GGUF array values are not supported by these functions
|
||||
|
||||
// Get metadata value as a string by key name
|
||||
LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
|
||||
|
||||
// Get the number of metadata key/value pairs
|
||||
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
|
||||
|
||||
// Get metadata key name by index
|
||||
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
// Get metadata value as a string by index
|
||||
LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
// Get a string describing the model type
|
||||
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||
LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||
|
||||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
@ -314,7 +378,7 @@ extern "C" {
|
||||
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_model_quantize(
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
const char * fname_out,
|
||||
const llama_model_quantize_params * params);
|
||||
@ -325,28 +389,79 @@ extern "C" {
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
|
||||
LLAMA_API DEPRECATED(int32_t llama_apply_lora_from_file(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora,
|
||||
float scale,
|
||||
const char * path_base_model,
|
||||
int n_threads),
|
||||
int32_t n_threads),
|
||||
"use llama_model_apply_lora_from_file instead");
|
||||
|
||||
LLAMA_API int llama_model_apply_lora_from_file(
|
||||
LLAMA_API int32_t llama_model_apply_lora_from_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_lora,
|
||||
float scale,
|
||||
const char * path_base_model,
|
||||
int n_threads);
|
||||
int32_t n_threads);
|
||||
|
||||
//
|
||||
// KV cache
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
||||
// Information associated with an individual cell in the KV cache view.
|
||||
struct llama_kv_cache_view_cell {
|
||||
// The position for this cell. Takes KV cache shifts into account.
|
||||
// May be negative if the cell is not populated.
|
||||
llama_pos pos;
|
||||
};
|
||||
|
||||
// An updateable view of the KV cache.
|
||||
struct llama_kv_cache_view {
|
||||
// Number of KV cache cells. This will be the same as the context size.
|
||||
int32_t n_cells;
|
||||
|
||||
// Maximum number of sequences that can exist in a cell. It's not an error
|
||||
// if there are more sequences in a cell than this value, however they will
|
||||
// not be visible in the view cells_sequences.
|
||||
int32_t n_max_seq;
|
||||
|
||||
// Number of tokens in the cache. For example, if there are two populated
|
||||
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
||||
// ids then you'll have 3 tokens.
|
||||
int32_t token_count;
|
||||
|
||||
// Number of populated cache cells.
|
||||
int32_t used_cells;
|
||||
|
||||
// Maximum contiguous empty slots in the cache.
|
||||
int32_t max_contiguous;
|
||||
|
||||
// Index to the start of the max_contiguous slot range. Can be negative
|
||||
// when cache is full.
|
||||
int32_t max_contiguous_idx;
|
||||
|
||||
// Information for an individual cell.
|
||||
struct llama_kv_cache_view_cell * cells;
|
||||
|
||||
// The sequences for each cell. There will be n_max_seq items per cell.
|
||||
llama_seq_id * cells_sequences;
|
||||
};
|
||||
|
||||
// Create an empty KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
|
||||
|
||||
// Free a KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
||||
|
||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
||||
|
||||
// Clear the KV cache
|
||||
LLAMA_API void llama_kv_cache_clear(
|
||||
@ -389,6 +504,17 @@ extern "C" {
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
//
|
||||
@ -437,7 +563,7 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past),
|
||||
int32_t n_past),
|
||||
"use llama_decode() instead");
|
||||
|
||||
// Same as llama_eval, but use float matrix input directly.
|
||||
@ -446,7 +572,7 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
float * embd,
|
||||
int32_t n_tokens,
|
||||
int n_past),
|
||||
int32_t n_past),
|
||||
"use llama_decode() instead");
|
||||
|
||||
// Return batch for single sequence of tokens starting at pos_0
|
||||
@ -478,7 +604,7 @@ extern "C" {
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error
|
||||
LLAMA_API int llama_decode(
|
||||
LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
@ -517,6 +643,12 @@ extern "C" {
|
||||
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
||||
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||
|
||||
// Returns -1 if unknown, 1 for true or 0 for false.
|
||||
LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
|
||||
|
||||
// Returns -1 if unknown, 1 for true or 0 for false.
|
||||
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
|
||||
|
||||
// codellama infill tokens
|
||||
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
||||
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
||||
@ -533,12 +665,12 @@ extern "C" {
|
||||
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||
/// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
||||
/// Does not insert a leading space.
|
||||
LLAMA_API int llama_tokenize(
|
||||
LLAMA_API int32_t llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const char * text,
|
||||
int text_len,
|
||||
int32_t text_len,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
int32_t n_max_tokens,
|
||||
bool add_bos,
|
||||
bool special);
|
||||
|
||||
@ -546,11 +678,11 @@ extern "C" {
|
||||
// Uses the vocabulary in the provided context.
|
||||
// Does not write null terminator to the buffer.
|
||||
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||
LLAMA_API int llama_token_to_piece(
|
||||
LLAMA_API int32_t llama_token_to_piece(
|
||||
const struct llama_model * model,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int length);
|
||||
int32_t length);
|
||||
|
||||
//
|
||||
// Grammar
|
||||
@ -584,14 +716,21 @@ extern "C" {
|
||||
float penalty_present);
|
||||
|
||||
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
LLAMA_API void llama_sample_classifier_free_guidance(
|
||||
/// @param logits Logits extracted from the original generation context.
|
||||
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
LLAMA_API void llama_sample_apply_guidance(
|
||||
struct llama_context * ctx,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale);
|
||||
|
||||
LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_context * guidance_ctx,
|
||||
float scale);
|
||||
float scale),
|
||||
"use llama_sample_apply_guidance() instead");
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(
|
||||
@ -602,7 +741,7 @@ extern "C" {
|
||||
LLAMA_API void llama_sample_top_k(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
int k,
|
||||
int32_t k,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
@ -633,6 +772,14 @@ extern "C" {
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
|
||||
LLAMA_API void llama_sample_entropy(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates_p,
|
||||
float min_temp,
|
||||
float max_temp,
|
||||
float exponent_val);
|
||||
|
||||
LLAMA_API void llama_sample_temp(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
@ -661,7 +808,7 @@ extern "C" {
|
||||
llama_token_data_array * candidates,
|
||||
float tau,
|
||||
float eta,
|
||||
int m,
|
||||
int32_t m,
|
||||
float * mu);
|
||||
|
||||
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
@ -734,8 +881,8 @@ extern "C" {
|
||||
llama_beam_search_callback_fn_t callback,
|
||||
void * callback_data,
|
||||
size_t n_beams,
|
||||
int n_past,
|
||||
int n_predict);
|
||||
int32_t n_past,
|
||||
int32_t n_predict);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||
|
@ -9,6 +9,14 @@
|
||||
#
|
||||
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# piper
|
||||
#
|
||||
# https://github.com/rhasspy/piper
|
||||
#
|
||||
# Tested with Linux:
|
||||
#
|
||||
#echo "$2" | piper --model ~/en_US-lessac-medium.onnx --output-raw | aplay -q -r 22050 -f S16_LE -t raw -
|
||||
|
||||
# for Mac
|
||||
say "$2"
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
|
||||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
auto * model = llama_get_model(ctx);
|
||||
@ -53,6 +54,7 @@ struct whisper_params {
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
int32_t n_gpu_layers = 999;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
@ -66,6 +68,9 @@ struct whisper_params {
|
||||
bool use_gpu = true;
|
||||
|
||||
std::string person = "Georgi";
|
||||
std::string bot_name = "LLaMA";
|
||||
std::string wake_cmd = "";
|
||||
std::string heard_ok = "";
|
||||
std::string language = "en";
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_llama = "models/ggml-llama-7B.bin";
|
||||
@ -90,6 +95,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
@ -99,7 +105,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
||||
else if (arg == "--session") { params.path_session = argv[++i];}
|
||||
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
|
||||
else if (arg == "--session") { params.path_session = argv[++i]; }
|
||||
else if (arg == "-w" || arg == "--wake-command") { params.wake_cmd = argv[++i]; }
|
||||
else if (arg == "-ho" || arg == "--heard-ok") { params.heard_ok = argv[++i]; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
||||
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
|
||||
@ -134,6 +143,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
@ -143,6 +153,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
||||
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
|
||||
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
|
||||
fprintf(stderr, " -ho TEXT, --heard-ok TEXT [%-7s] said by TTS before generating reply\n", params.heard_ok.c_str());
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
||||
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
|
||||
@ -221,6 +234,18 @@ std::string transcribe(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::string> get_words(const std::string &txt) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
std::istringstream iss(txt);
|
||||
std::string word;
|
||||
while (iss >> word) {
|
||||
words.push_back(word);
|
||||
}
|
||||
|
||||
return words;
|
||||
}
|
||||
|
||||
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
|
||||
|
||||
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
|
||||
@ -256,7 +281,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
||||
@ -268,6 +293,8 @@ int main(int argc, char ** argv) {
|
||||
auto lmparams = llama_model_default_params();
|
||||
if (!params.use_gpu) {
|
||||
lmparams.n_gpu_layers = 0;
|
||||
} else {
|
||||
lmparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
|
||||
@ -277,7 +304,6 @@ int main(int argc, char ** argv) {
|
||||
// tune these to your liking
|
||||
lcparams.n_ctx = 2048;
|
||||
lcparams.seed = 1;
|
||||
lcparams.f16_kv = true;
|
||||
lcparams.n_threads = params.n_threads;
|
||||
|
||||
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
||||
@ -319,12 +345,11 @@ int main(int argc, char ** argv) {
|
||||
float prob0 = 0.0f;
|
||||
|
||||
const std::string chat_symb = ":";
|
||||
const std::string bot_name = "LLaMA";
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
|
||||
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", params.bot_name);
|
||||
|
||||
// construct the initial prompt for LLaMA inference
|
||||
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
|
||||
@ -333,7 +358,7 @@ int main(int argc, char ** argv) {
|
||||
prompt_llama.insert(0, 1, ' ');
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
|
||||
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
|
||||
prompt_llama = ::replace(prompt_llama, "{1}", params.bot_name);
|
||||
|
||||
{
|
||||
// get time string
|
||||
@ -435,6 +460,16 @@ int main(int argc, char ** argv) {
|
||||
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
|
||||
|
||||
printf("%s : done! start speaking in the microphone\n", __func__);
|
||||
|
||||
// show wake command if enabled
|
||||
const std::string wake_cmd = params.wake_cmd;
|
||||
const int wake_cmd_length = get_words(wake_cmd).size();
|
||||
const bool use_wake_cmd = wake_cmd_length > 0;
|
||||
|
||||
if (use_wake_cmd) {
|
||||
printf("%s : the wake-up command is: '%s%s%s'\n", __func__, "\033[1m", wake_cmd.c_str(), "\033[0m");
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
printf("%s%s", params.person.c_str(), chat_symb.c_str());
|
||||
fflush(stdout);
|
||||
@ -480,10 +515,41 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard;
|
||||
std::string all_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
|
||||
all_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
|
||||
}
|
||||
|
||||
const auto words = get_words(all_heard);
|
||||
|
||||
std::string wake_cmd_heard;
|
||||
std::string text_heard;
|
||||
|
||||
for (int i = 0; i < (int) words.size(); ++i) {
|
||||
if (i < wake_cmd_length) {
|
||||
wake_cmd_heard += words[i] + " ";
|
||||
} else {
|
||||
text_heard += words[i] + " ";
|
||||
}
|
||||
}
|
||||
|
||||
// check if audio starts with the wake-up command if enabled
|
||||
if (use_wake_cmd) {
|
||||
const float sim = similarity(wake_cmd_heard, wake_cmd);
|
||||
|
||||
if ((sim < 0.7f) || (text_heard.empty())) {
|
||||
audio.clear();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally give audio feedback that the current text is being processed
|
||||
if (!params.heard_ok.empty()) {
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + params.heard_ok + "'").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
@ -520,7 +586,7 @@ int main(int argc, char ** argv) {
|
||||
force_speak = false;
|
||||
|
||||
text_heard.insert(0, 1, ' ');
|
||||
text_heard += "\n" + bot_name + chat_symb;
|
||||
text_heard += "\n" + params.bot_name + chat_symb;
|
||||
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
|
||||
fflush(stdout);
|
||||
|
||||
@ -653,6 +719,7 @@ int main(int argc, char ** argv) {
|
||||
text_to_speak += llama_token_to_piece(ctx_llama, id);
|
||||
|
||||
printf("%s", llama_token_to_piece(ctx_llama, id).c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
@ -681,8 +748,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
text_to_speak = ::replace(text_to_speak, "'", "'\"'\"'");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + text_to_speak + "'").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||
}
|
||||
|
@ -2,8 +2,9 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
static const std::vector<std::pair<uint32_t, uint32_t>> digit_ranges = {
|
||||
{0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F},
|
||||
|
@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
|
||||
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
|
||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
|
||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
@ -524,8 +524,7 @@ bool gpt2_eval(
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
1.0f/sqrt(float(n_embd)/n_head));
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
|
@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
|
||||
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
|
||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
|
||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
@ -525,8 +525,7 @@ bool gpt2_eval(
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
1.0f/sqrt(float(n_embd)/n_head));
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
|
@ -184,7 +184,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// whisper init
|
||||
struct whisper_context_params cparams;
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
||||
|
@ -21,7 +21,7 @@ help()
|
||||
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
|
||||
echo "options:"
|
||||
echo "-s Step in seconds (default is $step)."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large' (default is '$model')."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large-v3' (default is '$model')."
|
||||
echo "-t Number of threads to use."
|
||||
echo "-h Print this help page."
|
||||
echo
|
||||
|
9
examples/wchess/CMakeLists.txt
Normal file
@ -0,0 +1,9 @@
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
add_subdirectory(libwchess)
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
add_subdirectory(wchess.wasm)
|
||||
else()
|
||||
add_subdirectory(wchess.cmd)
|
||||
endif()
|
45
examples/wchess/README.md
Normal file
@ -0,0 +1,45 @@
|
||||
# wchess
|
||||
|
||||
Voice-controlled chess using Whisper
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/wchess/
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
|
||||
|
||||
## Command-line tool
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake -DWHISPER_SDL2=1 ..
|
||||
make -j
|
||||
|
||||
./bin/wchess -m ../models/ggml-base.en.bin
|
||||
|
||||
Move: start
|
||||
|
||||
a b c d e f g h
|
||||
r n b q k b n r 8
|
||||
p p p p p p p p 7
|
||||
. * . * . * . * 6
|
||||
* . * . * . * . 5
|
||||
. * . * . * . * 4
|
||||
* . * . * . * . 3
|
||||
P P P P P P P P 2
|
||||
R N B Q K B N R 1
|
||||
|
||||
White's turn
|
||||
[(l)isten/(p)ause/(q)uit]:
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- Fix bugs in the chess moves logic
|
||||
- Improve web-browser audio capture - sometimes it does not record the voice properly
|
||||
- Add support for more languages by making the generated grammar string multilingual
|
||||
- Explore ways to improve the dynamic grammar to be narrower
|
||||
|
||||
PRs welcome!
|
||||
|
||||
## Thanks
|
||||
|
||||
- [chessboardjs](https://chessboardjs.com) for the neat chessboard JS library used in this demo
|
19
examples/wchess/libwchess/CMakeLists.txt
Normal file
@ -0,0 +1,19 @@
|
||||
add_library(wchess-core STATIC
|
||||
WChess.cpp
|
||||
WChess.h
|
||||
Chessboard.cpp
|
||||
Chessboard.h
|
||||
)
|
||||
|
||||
target_link_libraries(wchess-core
|
||||
PUBLIC
|
||||
whisper
|
||||
common
|
||||
)
|
||||
|
||||
target_include_directories(wchess-core
|
||||
PUBLIC
|
||||
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
|
||||
)
|
||||
|
||||
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
|
803
examples/wchess/libwchess/Chessboard.cpp
Normal file
@ -0,0 +1,803 @@
|
||||
#include "Chessboard.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
#include <list>
|
||||
#include <chrono>
|
||||
|
||||
namespace {
|
||||
constexpr std::array<const char*, 64> positions = {
|
||||
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
|
||||
"a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
|
||||
"a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
|
||||
"a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4",
|
||||
"a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5",
|
||||
"a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6",
|
||||
"a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
|
||||
"a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
|
||||
};
|
||||
constexpr char INVALID_POS = positions.size();
|
||||
constexpr int R = 0; // rank index
|
||||
constexpr int F = 1; // file index
|
||||
#define FILE (c[F] - '1')
|
||||
#define RANK (c[R] - 'a')
|
||||
constexpr char operator ""_P(const char * c, size_t size) {
|
||||
return size < 2 || RANK < 0 || RANK > 7 ||
|
||||
FILE < 0 || FILE > 7 ? INVALID_POS : FILE * 8 + RANK;
|
||||
}
|
||||
#undef FILE
|
||||
#undef RANK
|
||||
|
||||
struct sview {
|
||||
const char * ptr = nullptr;
|
||||
size_t size = 0;
|
||||
|
||||
sview() = default;
|
||||
sview(const char * p, size_t s) : ptr(p), size(s) {}
|
||||
sview(const std::string& s) : ptr(s.data()), size(s.size()) {}
|
||||
|
||||
size_t find(char del, size_t pos) {
|
||||
while (pos < size && ptr[pos] != del) ++pos;
|
||||
return pos < size ? pos : std::string::npos;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<sview> split(sview str, char del) {
|
||||
std::vector<sview> res;
|
||||
size_t cur = 0;
|
||||
size_t last = 0;
|
||||
while (cur != std::string::npos) {
|
||||
if (str.ptr[last] == ' ') {
|
||||
++last;
|
||||
continue;
|
||||
}
|
||||
cur = str.find(del, last);
|
||||
size_t len = cur == std::string::npos ? str.size - last : cur - last;
|
||||
res.emplace_back(str.ptr + last, len);
|
||||
last = cur + 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
char strToPos(sview str) {
|
||||
return operator ""_P(str.ptr, str.size);
|
||||
}
|
||||
|
||||
constexpr std::array<const char*, 6> pieceNames = {
|
||||
"pawn", "knight", "bishop", "rook", "queen", "king",
|
||||
};
|
||||
|
||||
static constexpr std::array<char, 6> blackShort = {
|
||||
'p', 'n', 'b', 'r', 'q', 'k',
|
||||
};
|
||||
static constexpr std::array<char, 6> whiteShort = {
|
||||
'P', 'N', 'B', 'R', 'Q', 'K',
|
||||
};
|
||||
|
||||
char strToType(sview str) {
|
||||
auto it = std::find_if(pieceNames.begin(), pieceNames.end(), [str] (const char* name) { return strncmp(name, str.ptr, str.size) == 0; });
|
||||
return it != pieceNames.end() ? it - pieceNames.begin() : pieceNames.size();
|
||||
}
|
||||
|
||||
// directions
|
||||
using Direction = std::array<char, 2>;
|
||||
|
||||
constexpr Direction N = {(char) 0, (char) 1};
|
||||
constexpr Direction NNE = {(char) 1, (char) 2};
|
||||
constexpr Direction NE = {(char) 1, (char) 1};
|
||||
constexpr Direction ENE = {(char) 2, (char) 1};
|
||||
constexpr Direction E = {(char) 1, (char) 0};
|
||||
constexpr Direction ESE = {(char) 2, (char) -1};
|
||||
constexpr Direction SE = {(char) 1, (char) -1};
|
||||
constexpr Direction SSE = {(char) 1, (char) -2};
|
||||
constexpr Direction S = {(char) 0, (char) -1};
|
||||
constexpr Direction SSW = {(char) -1, (char) -2};
|
||||
constexpr Direction SW = {(char) -1, (char) -1};
|
||||
constexpr Direction WSW = {(char) -2, (char) -1};
|
||||
constexpr Direction W = {(char) -1, (char) 0};
|
||||
constexpr Direction WNW = {(char) -2, (char) 1};
|
||||
constexpr Direction NW = {(char) -1, (char) 1};
|
||||
constexpr Direction NNW = {(char) -1, (char) 2};
|
||||
|
||||
char makeStep(char pos, const Direction& d) {
|
||||
char next[2] = { char(positions[pos][R] + d[R]) , char(positions[pos][F] + d[F]) };
|
||||
return strToPos(sview{next, sizeof(next)});
|
||||
}
|
||||
|
||||
template<class Modifier>
|
||||
char traverse(char pos, const Direction& d, const Modifier& m, int count = 8) {
|
||||
while (--count >= 0) {
|
||||
pos = makeStep(pos, d);
|
||||
if (pos == INVALID_POS || m(pos)) break;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
Direction normalize(const Direction& distance) {
|
||||
//return {char((distance[R] > 0) - (distance[R] < 0)), char((distance[F] > 0) - (distance[F] < 0))};
|
||||
const int drp = distance[R] > 0 ? 1 : 0;
|
||||
const int drn = distance[R] < 0 ? 1 : 0;
|
||||
const int dfp = distance[F] > 0 ? 1 : 0;
|
||||
const int dfn = distance[F] < 0 ? 1 : 0;
|
||||
return {char(drp - drn), char(dfp - dfn)};
|
||||
}
|
||||
|
||||
struct Pin {
|
||||
Direction d;
|
||||
Piece* pinner;
|
||||
Piece* pinned;
|
||||
};
|
||||
using Pins = std::list<Pin>;
|
||||
using Board = std::array<Piece*, 64>;
|
||||
|
||||
std::vector<Direction> filter(const Direction& pin, std::initializer_list<Direction> directions) {
|
||||
if (pin[R] == 0 && pin[F] == 0) return directions;
|
||||
std::vector<Direction> result;
|
||||
for (auto& d : directions) {
|
||||
if ((d[R] == pin[R] || d[R] == -pin[R]) && (d[F] == pin[F] || d[F] == -pin[F])) result.push_back(d);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
class Piece {
|
||||
public:
|
||||
enum Types : char {
|
||||
Pawn,
|
||||
Knight,
|
||||
Bishop,
|
||||
Rook,
|
||||
Queen,
|
||||
King,
|
||||
//
|
||||
NUM_PIECES
|
||||
};
|
||||
|
||||
enum Colors : char {
|
||||
White,
|
||||
Black,
|
||||
};
|
||||
|
||||
const char* name() const;
|
||||
char initial() const;
|
||||
Types type() const { return m_type; }
|
||||
Colors color() const { return m_color; }
|
||||
char pos() const { return m_pos; }
|
||||
void setPos(char pos) {
|
||||
m_pos = pos;
|
||||
invalidate();
|
||||
}
|
||||
const char* coord() const;
|
||||
const std::set<char>& allowed() const { return m_allowed; }
|
||||
bool canReach(char pos) const;
|
||||
virtual bool movePattern(char pos) const = 0;
|
||||
void take();
|
||||
virtual void reinit(const State& state) = 0;
|
||||
void invalidate();
|
||||
protected:
|
||||
Piece(Types type, Colors color, char pos, std::set<char> allowed)
|
||||
: m_type(type), m_color(color), m_pos(pos), m_allowed(std::move(allowed)) {}
|
||||
Piece(const Piece&) = delete;
|
||||
~Piece() = default;
|
||||
|
||||
const Types m_type;
|
||||
const Colors m_color;
|
||||
char m_pos;
|
||||
std::set<char> m_allowed;
|
||||
bool m_update = false;
|
||||
};
|
||||
|
||||
struct Pawn : public Piece {
|
||||
Pawn(Colors color, char pos, std::set<char> next) : Piece(Types::Pawn, color, pos, std::move(next)) {}
|
||||
|
||||
bool is_first_move() const {
|
||||
return m_color ? coord()[F] == '7' : coord()[F] == '2';
|
||||
}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
Direction distance = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
||||
char forward = m_color ? -1 : 1;
|
||||
return (forward == distance[F] && distance[R] * distance[R] <= 1)
|
||||
|| (is_first_move() && 2 * forward == distance[F] && distance[R] == 0);
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct Knight : public Piece {
|
||||
Knight(Colors color, char pos, std::set<char> next) : Piece(Types::Knight, color, pos, std::move(next)) {}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
||||
return diff[R]*diff[R] + diff[F]*diff[F] == 5;
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct Bishop : public Piece {
|
||||
Bishop(Colors color, char pos) : Piece(Types::Bishop, color, pos, {}) {}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
return cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct Rook : public Piece {
|
||||
Rook(Colors color, char pos) : Piece(Types::Rook, color, pos, {}) {}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
return cur[R] == next[R] || cur[F] == next[F];
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct Queen : public Piece {
|
||||
Queen(Colors color, char pos) : Piece(Types::Queen, color, pos, {}) {}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
return cur[R] == next[R] || cur[F] == next[F] || cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct King : public Piece {
|
||||
King(Colors color, char pos) : Piece(Types::King, color, pos, {}) {}
|
||||
|
||||
virtual bool movePattern(char pos) const override {
|
||||
if (m_pos == INVALID_POS) return false;
|
||||
auto cur = coord();
|
||||
auto next = positions[pos];
|
||||
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
||||
return diff[R]*diff[R] + diff[F]*diff[F] <= 2;
|
||||
}
|
||||
|
||||
virtual void reinit(const State& state) override;
|
||||
};
|
||||
|
||||
struct PieceSet {
|
||||
Piece* begin() { return &p1; }
|
||||
Piece* end() { return &r2 + 1; }
|
||||
const Piece* begin() const { return &p1; }
|
||||
const Piece* end() const { return &r2 + 1; }
|
||||
Piece& operator[](int i) { return *(begin() + i); }
|
||||
const Piece& operator[](int i) const { return *(begin() + i); }
|
||||
|
||||
Pawn p1;
|
||||
Pawn p2;
|
||||
Pawn p3;
|
||||
Pawn p4;
|
||||
Pawn p5;
|
||||
Pawn p6;
|
||||
Pawn p7;
|
||||
Pawn p8;
|
||||
Rook r1;
|
||||
Knight n1;
|
||||
Bishop b1;
|
||||
Queen q;
|
||||
King k;
|
||||
Bishop b2;
|
||||
Knight n2;
|
||||
Rook r2;
|
||||
};
|
||||
|
||||
struct State {
|
||||
State();
|
||||
PieceSet blacks;
|
||||
PieceSet whites;
|
||||
Board board;
|
||||
Pins blackPins;
|
||||
Pins whitePins;
|
||||
};
|
||||
|
||||
Direction findPin(const Piece& piece, const State& state) {
|
||||
auto& pins = piece.color() ? state.blackPins : state.whitePins;
|
||||
auto it = std::find_if(pins.begin(), pins.end(), [&] (const Pin& pin) { return pin.pinned == &piece; });
|
||||
if (it != pins.end()) return it->d;
|
||||
return {0, 0};
|
||||
}
|
||||
|
||||
struct Find {
|
||||
Find(const Board& board) : m_board(board) {}
|
||||
bool operator() (char pos) const { return m_board[pos]; }
|
||||
const Board& m_board;
|
||||
};
|
||||
|
||||
struct Add {
|
||||
Add(const Board& board, std::set<char>& moves, Piece::Colors color) : m_board(board), m_moves(moves), m_color(color) {}
|
||||
bool operator() (char pos) const {
|
||||
if (!m_board[pos] || m_board[pos]->color() != m_color) m_moves.insert(pos);
|
||||
return m_board[pos];
|
||||
}
|
||||
const Board& m_board;
|
||||
std::set<char>& m_moves;
|
||||
Piece::Colors m_color;
|
||||
};
|
||||
|
||||
void Pawn::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
|
||||
auto pin = findPin(*this, state);
|
||||
|
||||
auto & left = m_color ? SW : NW;
|
||||
auto & right = m_color ? SE : NE;
|
||||
|
||||
for (auto& direction : filter(pin, { left, right })) {
|
||||
auto pos = makeStep(m_pos, direction);
|
||||
if (pos != INVALID_POS && state.board[pos] && state.board[pos]->color() != m_color) m_allowed.insert(pos);
|
||||
}
|
||||
|
||||
auto & forward = m_color ? S : N;
|
||||
if (!filter(pin, {forward}).empty()) {
|
||||
traverse(m_pos, forward, [&] (char pos) {
|
||||
if (!state.board[pos]) m_allowed.insert(pos);
|
||||
return state.board[pos] || !is_first_move();
|
||||
}, 2);
|
||||
}
|
||||
}
|
||||
|
||||
void Knight::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
auto pin = findPin(*this, state);
|
||||
if (pin[R] != 0 || pin[F] != 0) return;
|
||||
for (auto& direction : { NNE, ENE, ESE, SSE, SSW, WSW, WNW, NNW }) {
|
||||
auto pos = makeStep(m_pos, direction);
|
||||
if (pos != INVALID_POS && (!state.board[pos] || state.board[pos]->color() != m_color)) m_allowed.insert(pos);
|
||||
}
|
||||
}
|
||||
|
||||
void Bishop::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
auto pin = findPin(*this, state);
|
||||
for (auto& direction : filter(pin, { NE, SE, SW, NW })) {
|
||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
||||
}
|
||||
}
|
||||
|
||||
void Rook::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
auto pin = findPin(*this, state);
|
||||
for (auto& direction : filter(pin, { N, E, S, W })) {
|
||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
||||
}
|
||||
}
|
||||
|
||||
void Queen::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
auto pin = findPin(*this, state);
|
||||
for (auto& direction : filter(pin, { N, NE, E, SE, S, SW, W, NW })) {
|
||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
||||
}
|
||||
}
|
||||
|
||||
void King::reinit(const State& state) {
|
||||
if (m_pos == INVALID_POS) return;
|
||||
if (!m_update) return;
|
||||
m_update = false;
|
||||
m_allowed.clear();
|
||||
auto& enemyPieces = m_color ? state.whites : state.blacks;
|
||||
auto& pawnAttackLeft = m_color ? SW : NW;
|
||||
auto& pawnAttackRight = m_color ? SE : NE;
|
||||
for (auto& direction : { N, NE, E, SE, S, SW, W, NW }) {
|
||||
auto pos = makeStep(m_pos, direction);
|
||||
bool accept = pos != INVALID_POS && !(state.board[pos] && state.board[pos]->color() == m_color);
|
||||
if (accept) {
|
||||
for (auto& p : enemyPieces) {
|
||||
if (!p.movePattern(pos)) continue;
|
||||
if (p.type() == Piece::Knight || p.type() == Piece::King) {
|
||||
accept = false;
|
||||
break;
|
||||
}
|
||||
else if (p.type() == Piece::Pawn) {
|
||||
auto from = positions[pos];
|
||||
auto to = p.coord();
|
||||
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
|
||||
if (d == pawnAttackLeft || d == pawnAttackRight) {
|
||||
accept = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto from = positions[pos];
|
||||
auto to = p.coord();
|
||||
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
|
||||
auto reached = traverse(pos, d, Find(state.board));
|
||||
if (p.pos() == reached) {
|
||||
accept = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (accept) m_allowed.insert(pos);
|
||||
}
|
||||
}
|
||||
|
||||
const char* Piece::name() const {
|
||||
static_assert(pieceNames.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
||||
return pieceNames[m_type];
|
||||
}
|
||||
|
||||
char Piece::initial() const {
|
||||
static_assert(blackShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
||||
static_assert(whiteShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
||||
return m_color ? blackShort[m_type] : whiteShort[m_type];
|
||||
}
|
||||
|
||||
void Piece::invalidate() {
|
||||
m_update = true;
|
||||
}
|
||||
|
||||
|
||||
const char* Piece::coord() const {
|
||||
if (m_pos == INVALID_POS) return "";
|
||||
return positions[m_pos];
|
||||
}
|
||||
|
||||
bool Piece::canReach(char pos) const {
|
||||
return movePattern(pos) && m_allowed.count(pos);
|
||||
}
|
||||
|
||||
void Piece::take() {
|
||||
m_pos = INVALID_POS;
|
||||
m_allowed = {};
|
||||
}
|
||||
|
||||
State::State()
|
||||
: blacks {
|
||||
{Piece::Black, "a7"_P, {"a5"_P, "a6"_P} },
|
||||
{Piece::Black, "b7"_P, {"b5"_P, "b6"_P} },
|
||||
{Piece::Black, "c7"_P, {"c5"_P, "c6"_P} },
|
||||
{Piece::Black, "d7"_P, {"d5"_P, "d6"_P} },
|
||||
{Piece::Black, "e7"_P, {"e5"_P, "e6"_P} },
|
||||
{Piece::Black, "f7"_P, {"f5"_P, "f6"_P} },
|
||||
{Piece::Black, "g7"_P, {"g5"_P, "g6"_P} },
|
||||
{Piece::Black, "h7"_P, {"h5"_P, "h6"_P} },
|
||||
{Piece::Black, "a8"_P},
|
||||
{Piece::Black, "b8"_P, {"a6"_P, "c6"_P} },
|
||||
{Piece::Black, "c8"_P},
|
||||
{Piece::Black, "d8"_P},
|
||||
{Piece::Black, "e8"_P},
|
||||
{Piece::Black, "f8"_P},
|
||||
{Piece::Black, "g8"_P, {"f6"_P, "h6"_P} },
|
||||
{Piece::Black, "h8"_P},
|
||||
}
|
||||
, whites {
|
||||
{Piece::White, "a2"_P, {"a3"_P, "a4"_P} },
|
||||
{Piece::White, "b2"_P, {"b3"_P, "b4"_P} },
|
||||
{Piece::White, "c2"_P, {"c3"_P, "c4"_P} },
|
||||
{Piece::White, "d2"_P, {"d3"_P, "d4"_P} },
|
||||
{Piece::White, "e2"_P, {"e3"_P, "e4"_P} },
|
||||
{Piece::White, "f2"_P, {"f3"_P, "f4"_P} },
|
||||
{Piece::White, "g2"_P, {"g3"_P, "g4"_P} },
|
||||
{Piece::White, "h2"_P, {"h3"_P, "h4"_P} },
|
||||
{Piece::White, "a1"_P},
|
||||
{Piece::White, "b1"_P, {"a3"_P, "c3"_P} },
|
||||
{Piece::White, "c1"_P},
|
||||
{Piece::White, "d1"_P},
|
||||
{Piece::White, "e1"_P},
|
||||
{Piece::White, "f1"_P},
|
||||
{Piece::White, "g1"_P, {"f3"_P, "h3"_P} },
|
||||
{Piece::White, "h1"_P},
|
||||
}
|
||||
, board {{
|
||||
&whites[ 8], &whites[ 9], &whites[10], &whites[11], &whites[12], &whites[13], &whites[14], &whites[15],
|
||||
&whites[ 0], &whites[ 1], &whites[ 2], &whites[ 3], &whites[ 4], &whites[ 5], &whites[ 6], &whites[ 7],
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
&blacks[ 0], &blacks[ 1], &blacks[ 2], &blacks[ 3], &blacks[ 4], &blacks[ 5], &blacks[ 6], &blacks[ 7],
|
||||
&blacks[ 8], &blacks[ 9], &blacks[10], &blacks[11], &blacks[12], &blacks[13], &blacks[14], &blacks[15],
|
||||
}}
|
||||
{}
|
||||
|
||||
Chessboard::Chessboard()
|
||||
: m_state(new State())
|
||||
{
|
||||
setGrammar();
|
||||
}
|
||||
|
||||
Chessboard::~Chessboard() = default;
|
||||
|
||||
void Chessboard::setPrompt(const std::string& prompt) {
|
||||
m_prompt = prompt;
|
||||
setGrammar();
|
||||
}
|
||||
|
||||
void Chessboard::setGrammar() {
|
||||
m_grammar.clear();
|
||||
|
||||
std::string result;
|
||||
if (m_prompt.empty()) {
|
||||
result += "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n";
|
||||
//result += "move ::= \" \" frompos \" \" \"to \"? topos\n";
|
||||
}
|
||||
else {
|
||||
// result += "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n"
|
||||
result += "move ::= prompt \" \" frompos \" \" \"to \"? topos\n"
|
||||
"prompt ::= \" " + m_prompt + "\"\n";
|
||||
}
|
||||
|
||||
std::set<Piece::Types> pieceTypes;
|
||||
std::set<char> from_pos;
|
||||
std::set<char> to_pos;
|
||||
auto& pieces = m_moveCounter % 2 ? m_state->blacks : m_state->whites;
|
||||
std::set<size_t> flags;
|
||||
for (auto& p : pieces) {
|
||||
if (p.allowed().empty()) continue;
|
||||
bool addPiece = false;
|
||||
if (!m_inCheck || p.type() == Piece::King) {
|
||||
to_pos.insert(p.allowed().begin(), p.allowed().end());
|
||||
addPiece = !p.allowed().empty();
|
||||
}
|
||||
else {
|
||||
for (auto move : p.allowed()) {
|
||||
if (m_allowedInCheck.count(move)) {
|
||||
to_pos.insert(move);
|
||||
addPiece = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (addPiece) {
|
||||
pieceTypes.insert(p.type());
|
||||
from_pos.insert(p.pos());
|
||||
}
|
||||
}
|
||||
if (pieceTypes.empty()) return;
|
||||
|
||||
result += "piece ::= (";
|
||||
for (auto& p : pieceTypes) result += " \"" + std::string(pieceNames[p]) + "\" |";
|
||||
result.pop_back();
|
||||
result += ")\n\n";
|
||||
|
||||
result += "frompos ::= (";
|
||||
for (auto& p : from_pos) result += " \"" + std::string(positions[p]) + "\" |";
|
||||
result.pop_back();
|
||||
result += ")\n";
|
||||
|
||||
result += "topos ::= (";
|
||||
for (auto& p : to_pos) result += " \"" + std::string(positions[p]) + "\" |";
|
||||
result.pop_back();
|
||||
result += ")\n";
|
||||
|
||||
m_grammar = std::move(result);
|
||||
}
|
||||
|
||||
std::string Chessboard::stringifyBoard() {
|
||||
std::string result;
|
||||
result.reserve(16 + 2 * 64 + 16);
|
||||
for (char rank = 'a'; rank <= 'h'; ++rank) {
|
||||
result.push_back(rank);
|
||||
result.push_back(' ');
|
||||
}
|
||||
result.back() = '\n';
|
||||
for (int i = 7; i >= 0; --i) {
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
auto p = m_state->board[i * 8 + j];
|
||||
if (p) result.push_back(p->initial());
|
||||
else result.push_back((i + j) % 2 ? '.' : '*');
|
||||
result.push_back(' ');
|
||||
}
|
||||
result.push_back('0' + i + 1);
|
||||
result.push_back('\n');
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Chessboard::process(const std::string& command) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
auto color = Piece::Colors(m_moveCounter % 2);
|
||||
Piece* piece = nullptr;
|
||||
auto pos_to = INVALID_POS;
|
||||
if (!parseCommand(command, piece, pos_to)) return "";
|
||||
|
||||
auto pos_from = piece->pos();
|
||||
|
||||
if (!move(*piece, pos_to)) return "";
|
||||
|
||||
flagUpdates(pos_from, pos_to);
|
||||
|
||||
detectChecks();
|
||||
|
||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
||||
for (auto& p : enemyPieces) p.reinit(*m_state); // only enemy moves needed next
|
||||
|
||||
std::string result = {positions[pos_from][R], positions[pos_from][F], '-', positions[pos_to][R], positions[pos_to][F]};
|
||||
++m_moveCounter;
|
||||
setGrammar();
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
auto t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
fprintf(stdout, "%s: Move '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", result.data(), "\033[0m", (int) t_ms);
|
||||
if (m_grammar.empty()) result.push_back('#');
|
||||
return result;
|
||||
}
|
||||
|
||||
bool Chessboard::parseCommand(const std::string& command, Piece*& piece, char& pos_to) {
|
||||
auto color = Piece::Colors(m_moveCounter % 2);
|
||||
fprintf(stdout, "%s: Command to %s: '%s%.*s%s'\n", __func__, (color ? "Black" : "White"), "\033[1m", int(command.size()), command.data(), "\033[0m");
|
||||
|
||||
if (command.empty()) return false;
|
||||
auto tokens = split(command, ' ');
|
||||
auto pos_from = INVALID_POS;
|
||||
auto type = Piece::Types::NUM_PIECES;
|
||||
if (tokens.size() == 1) {
|
||||
type = Piece::Types::Pawn;
|
||||
pos_to = strToPos(tokens.front());
|
||||
}
|
||||
else {
|
||||
pos_from = strToPos(tokens.front());
|
||||
if (pos_from == INVALID_POS) type = Piece::Types(strToType(tokens.front()));
|
||||
pos_to = strToPos(tokens.back());
|
||||
}
|
||||
if (pos_to == INVALID_POS) return false;
|
||||
if (pos_from == INVALID_POS) {
|
||||
if (type == Piece::Types::NUM_PIECES) return false;
|
||||
auto& pieces = color ? m_state->blacks : m_state->whites;
|
||||
for (auto& p : pieces) {
|
||||
if (p.type() == type && p.canReach(pos_to)) {
|
||||
pos_from = p.pos();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (pos_from == INVALID_POS) return false;
|
||||
if (m_state->board[pos_from] == nullptr) return false;
|
||||
piece = m_state->board[pos_from];
|
||||
if (piece->color() != color) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Chessboard::flagUpdates(char pos_from, char pos_to) {
|
||||
auto color = Piece::Colors(m_moveCounter % 2);
|
||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
||||
auto& ownPieces = color ? m_state->blacks : m_state->whites;
|
||||
for (auto& p : enemyPieces) {
|
||||
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
|
||||
updatePins(p);
|
||||
p.invalidate();
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& p : ownPieces) {
|
||||
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
|
||||
updatePins(p);
|
||||
p.invalidate();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Chessboard::updatePins(Piece& piece) {
|
||||
if (piece.type() == Piece::Pawn || piece.type() == Piece::Knight || piece.type() == Piece::King) return;
|
||||
auto& enemyPieces = piece.color() ? m_state->whites : m_state->blacks;
|
||||
auto& enemyPins = piece.color() ? m_state->whitePins : m_state->blackPins;
|
||||
auto& king = enemyPieces.k;
|
||||
auto it = std::find_if(enemyPins.begin(), enemyPins.end(), [&] (const Pin& pin) { return pin.pinner == &piece; });
|
||||
if (it != enemyPins.end()) {
|
||||
it->pinned->invalidate();
|
||||
enemyPins.erase(it);
|
||||
}
|
||||
if (piece.movePattern(king.pos())) {
|
||||
auto to = positions[king.pos()];
|
||||
auto from = piece.coord();
|
||||
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
|
||||
|
||||
auto reached = traverse(piece.pos(), d, Find(m_state->board));
|
||||
auto foundPiece = m_state->board[reached];
|
||||
if (&king == foundPiece) {
|
||||
// check
|
||||
king.invalidate();
|
||||
}
|
||||
else if (foundPiece && foundPiece->color() != piece.color()) {
|
||||
reached = traverse(reached, d, Find(m_state->board));
|
||||
if (&king == m_state->board[reached]) {
|
||||
enemyPins.push_back({d, &piece, foundPiece});
|
||||
foundPiece->invalidate();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Chessboard::detectChecks() {
|
||||
auto color = Piece::Colors(m_moveCounter % 2);
|
||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
||||
auto& ownPieces = color ? m_state->blacks : m_state->whites;
|
||||
auto& king = enemyPieces.k;
|
||||
auto& pawnAttackLeft = color ? SW : NW;
|
||||
auto& pawnAttackRight = color ? SE : NE;
|
||||
for (auto& p : ownPieces) {
|
||||
if (!p.movePattern(king.pos())) continue;
|
||||
auto to = positions[king.pos()];
|
||||
auto from = p.coord();
|
||||
|
||||
if (p.type() == Piece::Knight) {
|
||||
if (!m_inCheck) {
|
||||
m_allowedInCheck = { p.pos() };
|
||||
}
|
||||
else {
|
||||
m_allowedInCheck.clear();
|
||||
}
|
||||
m_inCheck = true;
|
||||
}
|
||||
else if (p.type() == Piece::Pawn) {
|
||||
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
|
||||
if (d == pawnAttackLeft || d == pawnAttackRight) {
|
||||
if (!m_inCheck) {
|
||||
m_allowedInCheck = { p.pos() };
|
||||
}
|
||||
else {
|
||||
m_allowedInCheck.clear();
|
||||
}
|
||||
m_inCheck = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
|
||||
std::set<char> tmp;
|
||||
auto pos = traverse(p.pos(), d, Add(m_state->board, tmp, king.color()));
|
||||
if (pos == king.pos()) {
|
||||
tmp.insert(p.pos());
|
||||
if (!m_inCheck) {
|
||||
m_allowedInCheck = std::move(tmp);
|
||||
}
|
||||
else {
|
||||
m_allowedInCheck.clear();
|
||||
}
|
||||
m_inCheck = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Chessboard::move(Piece& piece, char pos_to) {
|
||||
auto& allowed = piece.allowed();
|
||||
|
||||
if (allowed.count(pos_to) == 0 || (m_inCheck && piece.type() != Piece::King && m_allowedInCheck.count(pos_to) == 0)) return false;
|
||||
if (m_state->board[pos_to] && m_state->board[pos_to]->color() == piece.color()) return false;
|
||||
if (m_state->board[pos_to]) m_state->board[pos_to]->take();
|
||||
m_state->board[piece.pos()] = nullptr;
|
||||
m_state->board[pos_to] = &piece;
|
||||
piece.setPos(pos_to);
|
||||
|
||||
m_inCheck = false;
|
||||
m_allowedInCheck.clear();
|
||||
|
||||
return true;
|
||||
}
|
33
examples/wchess/libwchess/Chessboard.h
Normal file
@ -0,0 +1,33 @@
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
|
||||
// just basic validation
|
||||
// fixme: missing en passant, castling, promotion, etc.
|
||||
struct State;
|
||||
class Piece;
|
||||
class Chessboard {
|
||||
public:
|
||||
Chessboard();
|
||||
~Chessboard();
|
||||
std::string process(const std::string& command);
|
||||
std::string stringifyBoard();
|
||||
const std::string& grammar() { return m_grammar; }
|
||||
const std::string& prompt() { return m_prompt; }
|
||||
void setPrompt(const std::string& prompt);
|
||||
private:
|
||||
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
|
||||
bool move(Piece& piece, char pos);
|
||||
void flagUpdates(char pos_from, char pos_to);
|
||||
void updatePins(Piece& piece);
|
||||
void detectChecks();
|
||||
void setGrammar();
|
||||
|
||||
std::unique_ptr<State> m_state;
|
||||
std::set<char> m_allowedInCheck;
|
||||
bool m_inCheck = false;
|
||||
int m_moveCounter = 0;
|
||||
std::string m_grammar;
|
||||
std::string m_prompt;
|
||||
};
|
193
examples/wchess/libwchess/WChess.cpp
Normal file
@ -0,0 +1,193 @@
|
||||
#include "WChess.h"
|
||||
#include "Chessboard.h"
|
||||
#include "grammar-parser.h"
|
||||
#include "common.h"
|
||||
#include <thread>
|
||||
|
||||
WChess::WChess(whisper_context * ctx,
|
||||
const whisper_full_params & wparams,
|
||||
callbacks cb,
|
||||
settings s)
|
||||
: m_ctx(ctx)
|
||||
, m_wparams(wparams)
|
||||
, m_cb(cb)
|
||||
, m_settings(s)
|
||||
, m_board(new Chessboard())
|
||||
{}
|
||||
|
||||
WChess::~WChess() = default;
|
||||
|
||||
void WChess::set_move(const std::string& moves, float prob) const {
|
||||
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
|
||||
}
|
||||
|
||||
void WChess::set_grammar(const std::string& grammar) const {
|
||||
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
|
||||
}
|
||||
|
||||
bool WChess::get_audio(std::vector<float>& pcmf32) const {
|
||||
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string WChess::stringify_board() const {
|
||||
return m_board->stringifyBoard();
|
||||
}
|
||||
|
||||
std::string WChess::get_grammar() const {
|
||||
return m_board->grammar();
|
||||
}
|
||||
|
||||
void WChess::run() {
|
||||
bool have_prompt = true;
|
||||
bool ask_prompt = !have_prompt;
|
||||
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
|
||||
int64_t t_ms = 0;
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
while (get_audio(pcmf32_cur)) {
|
||||
if (!pcmf32_cur.empty()) {
|
||||
// fprintf(stdout, "%s: Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
m_board->setPrompt(k_prompt);
|
||||
}
|
||||
} else {
|
||||
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
|
||||
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
|
||||
|
||||
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());
|
||||
|
||||
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
m_wparams.grammar_rules = grammar_rules.data();
|
||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
|
||||
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
if (!command.empty()) {
|
||||
set_move(m_board->process(command), p);
|
||||
set_grammar(m_board->grammar());
|
||||
}
|
||||
if (m_board->grammar().empty()) {
|
||||
fprintf(stdout, "%s: No more moves possible\n", __func__);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string WChess::transcribe(
|
||||
const std::vector<float> & pcmf32,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
t_ms = 0;
|
||||
|
||||
if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(m_ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(m_ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n = whisper_full_n_tokens(m_ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const auto token = whisper_full_get_token_data(m_ctx, i, j);
|
||||
|
||||
if(token.plog > 0.0f) return {};
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
63
examples/wchess/libwchess/WChess.h
Normal file
@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
class Chessboard;
|
||||
|
||||
class WChess {
|
||||
public:
|
||||
using CheckRunningCb = bool (*)();
|
||||
using GetAudioCb = bool (*)(std::vector<float> &);
|
||||
using SetMovesCb = void (*)(const std::string &, float);
|
||||
using SetGrammarCb = void (*)(const std::string &);
|
||||
using ClearAudioCb = void (*)();
|
||||
|
||||
struct callbacks {
|
||||
GetAudioCb get_audio = nullptr;
|
||||
SetMovesCb set_move = nullptr;
|
||||
SetGrammarCb set_grammar = nullptr;
|
||||
};
|
||||
|
||||
struct settings {
|
||||
int32_t vad_ms = 2000;
|
||||
int32_t prompt_ms = 5000;
|
||||
int32_t command_ms = 4000;
|
||||
float vad_thold = 0.2f;
|
||||
float freq_thold = 100.0f;
|
||||
bool print_energy = false;
|
||||
};
|
||||
|
||||
WChess(
|
||||
whisper_context * ctx,
|
||||
const whisper_full_params & wparams,
|
||||
callbacks cb,
|
||||
settings s
|
||||
);
|
||||
~WChess();
|
||||
|
||||
void run();
|
||||
|
||||
std::string stringify_board() const;
|
||||
|
||||
std::string get_grammar() const;
|
||||
|
||||
private:
|
||||
bool get_audio(std::vector<float>& pcmf32) const;
|
||||
void set_move(const std::string& moves, float prob) const;
|
||||
void set_grammar(const std::string& grammar) const;
|
||||
|
||||
std::string transcribe(
|
||||
const std::vector<float> & pcmf32,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms);
|
||||
|
||||
whisper_context * m_ctx;
|
||||
whisper_full_params m_wparams;
|
||||
const callbacks m_cb;
|
||||
const settings m_settings;
|
||||
std::unique_ptr<Chessboard> m_board;
|
||||
};
|
117
examples/wchess/libwchess/test-chessboard.cpp
Normal file
@ -0,0 +1,117 @@
|
||||
#include "Chessboard.h"
|
||||
|
||||
#define ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
fprintf(stderr, "ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||
fflush(stderr); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
int main() {
|
||||
{
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("pawn to d4") == "d2-d4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("c1 h6") == "c1-h6");
|
||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
||||
ASSERT(chess.process("bishop to g5") == "h6-g5");
|
||||
ASSERT(chess.process("bishop to b4") == "f8-b4");
|
||||
ASSERT(chess.process("c4") == "");
|
||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
||||
ASSERT(chess.process("f3") == "");
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("d4") == "d2-d4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
||||
ASSERT(chess.process("queen h5") == "d1-h5");
|
||||
ASSERT(chess.process("f5") == "");
|
||||
ASSERT(chess.process("g6") == "g7-g6");
|
||||
ASSERT(chess.process("knight e2") == "g1-e2");
|
||||
ASSERT(chess.process("f5") == "f7-f5");
|
||||
ASSERT(chess.process("knight g3") == "e2-g3");
|
||||
ASSERT(chess.process("g5") == "");
|
||||
ASSERT(chess.process("king e7") == "e8-e7");
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("g5") == "g6-g5");
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("c5") == "c7-c5");
|
||||
ASSERT(chess.process("e5") == "e4-e5");
|
||||
ASSERT(chess.process("c4") == "c5-c4");
|
||||
ASSERT(chess.process("e6") == "e5-e6");
|
||||
ASSERT(chess.process("c3") == "c4-c3");
|
||||
ASSERT(chess.process("e7") == "");
|
||||
ASSERT(chess.process("f7") == "e6-f7");
|
||||
ASSERT(chess.process("d2") == "");
|
||||
ASSERT(chess.process("king to f7") == "e8-f7");
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("d2") == "c3-d2");
|
||||
ASSERT(chess.process("f5") == "");
|
||||
ASSERT(chess.process("king to e2") == "e1-e2");
|
||||
ASSERT(chess.process("king to g6") == "f7-g6");
|
||||
ASSERT(chess.process("f5") == "f4-f5");
|
||||
ASSERT(chess.process("e6") == "");
|
||||
ASSERT(chess.process("king to h5") == "g6-h5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("king to g5") == "h5-g5");
|
||||
ASSERT(chess.process("h4") == "h2-h4");
|
||||
ASSERT(chess.process("king to h5") == "");
|
||||
ASSERT(chess.process("king to g6") == "");
|
||||
ASSERT(chess.process("king to h6") == "g5-h6");
|
||||
ASSERT(chess.process("bishop to d2") == "c1-d2");
|
||||
ASSERT(chess.process("king to g5") == "");
|
||||
ASSERT(chess.process("g5") == "g7-g5");
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("queen to h4") == "d8-h4#");
|
||||
ASSERT(chess.process("knight f3") == "");
|
||||
ASSERT(chess.grammar().empty());
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("d5") == "d7-d5");
|
||||
ASSERT(chess.process("g1 f3") == "g1-f3");
|
||||
ASSERT(chess.process("queen to h4") == "d8-h4");
|
||||
ASSERT(!chess.grammar().empty());
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
||||
ASSERT(chess.process("knight b5") == "c3-b5");
|
||||
ASSERT(chess.process("knight f6") == "g8-f6");
|
||||
ASSERT(chess.process("knight d6") == "b5-d6");
|
||||
ASSERT(chess.process("knight d4") == "");
|
||||
ASSERT(chess.process("d6") == "c7-d6");
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("knight d4") == "c6-d4");
|
||||
ASSERT(chess.process("d3") == "d2-d3");
|
||||
ASSERT(chess.process("knight e4") == "f6-e4");
|
||||
ASSERT(chess.process("king to e2") == "");
|
||||
ASSERT(chess.process("king to d2") == "");
|
||||
}
|
||||
}
|
8
examples/wchess/wchess.cmd/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
||||
if (WHISPER_SDL2)
|
||||
set(TARGET wchess)
|
||||
add_executable(${TARGET} wchess.cmd.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE wchess-core common-sdl ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
247
examples/wchess/wchess.cmd/wchess.cmd.cpp
Normal file
@ -0,0 +1,247 @@
|
||||
// Command line voice assisted chess
|
||||
//
|
||||
// Speak chess move commands to the microphone.
|
||||
// The moves will translated to chessboard positions.
|
||||
//
|
||||
//
|
||||
|
||||
#include "WChess.h"
|
||||
#include "common-sdl.h"
|
||||
#include <iostream>
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t prompt_ms = 5000;
|
||||
int32_t command_ms = 8000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
bool use_gpu = true;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<WChess> g_wchess;
|
||||
int g_moveCount = 0;
|
||||
void set_move(const std::string & move, float) {
|
||||
if (!move.empty()) {
|
||||
g_moveCount++;
|
||||
fprintf(stdout, "Move: %s\n\n", move.c_str());
|
||||
}
|
||||
else fprintf(stdout, "Move rejected\n\n");
|
||||
fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
|
||||
fprintf(stdout, "%s\n", g_moveCount ? "White's turn" : "Black's turn");
|
||||
}
|
||||
|
||||
audio_async g_audio(30*1000);
|
||||
bool g_listening = false;
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
bool read_input() {
|
||||
std::string input;
|
||||
while (true) {
|
||||
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
|
||||
std::cin >> input;
|
||||
fprintf(stdout, "\n");
|
||||
if (input[0] == 'q') {
|
||||
fprintf(stdout, "Quitting\n");
|
||||
return false;
|
||||
}
|
||||
if (input[0] == 'l') {
|
||||
if (!g_listening) {
|
||||
fprintf(stdout, "Listening\n");
|
||||
g_listening = true;
|
||||
g_pcmf32.clear();
|
||||
g_audio.resume();
|
||||
g_audio.clear();
|
||||
}
|
||||
else fprintf(stdout, "Still listening\n");
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
if (g_listening) {
|
||||
g_listening = false;
|
||||
g_audio.get(0, g_pcmf32);
|
||||
g_audio.pause();
|
||||
fprintf(stdout, "Processing\n");
|
||||
}
|
||||
else fprintf(stdout, "Not listening\n");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool get_audio(std::vector<float> & pcmf32_cur) {
|
||||
if (!read_input()) return false;
|
||||
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
|
||||
else pcmf32_cur.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// init audio
|
||||
|
||||
if (!g_audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
wparams.no_timestamps = true;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
wparams.temperature = 0.0f;
|
||||
wparams.temperature_inc = 2.0f;
|
||||
wparams.greedy.best_of = 1;
|
||||
|
||||
wparams.beam_search.beam_size = 1;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
wparams.grammar_penalty = 100.0;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
WChess::callbacks cb;
|
||||
cb.get_audio = get_audio;
|
||||
cb.set_move = set_move;
|
||||
|
||||
WChess::settings s;
|
||||
s.vad_ms = 2000;
|
||||
s.prompt_ms = params.prompt_ms;
|
||||
s.command_ms = params.command_ms;
|
||||
s.vad_thold = params.vad_thold;
|
||||
s.freq_thold = params.freq_thold;
|
||||
s.print_energy = params.print_energy;
|
||||
|
||||
g_wchess.reset(new WChess(ctx, wparams, cb, s));
|
||||
set_move("start", 0);
|
||||
g_wchess->run();
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
51
examples/wchess/wchess.wasm/CMakeLists.txt
Normal file
@ -0,0 +1,51 @@
|
||||
set(TARGET wchess.wasm)
|
||||
|
||||
add_executable(${TARGET}
|
||||
wchess.wasm.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
common
|
||||
wchess-core
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside chess.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/${TARGET}.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/chess.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/chessboardjs-1.0.0
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jquery-3.7.1.min.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/
|
||||
)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_SOURCE_DIR}/examples/helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY)
|
@ -0,0 +1,54 @@
|
||||
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
|
||||
|
||||
.clearfix-7da63 {
|
||||
clear: both;
|
||||
}
|
||||
|
||||
.board-b72b1 {
|
||||
border: 2px solid #404040;
|
||||
box-sizing: content-box;
|
||||
}
|
||||
|
||||
.square-55d63 {
|
||||
float: left;
|
||||
position: relative;
|
||||
|
||||
/* disable any native browser highlighting */
|
||||
-webkit-touch-callout: none;
|
||||
-webkit-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.white-1e1d7 {
|
||||
background-color: #f0d9b5;
|
||||
color: #b58863;
|
||||
}
|
||||
|
||||
.black-3c85d {
|
||||
background-color: #b58863;
|
||||
color: #f0d9b5;
|
||||
}
|
||||
|
||||
.highlight1-32417, .highlight2-9c5d2 {
|
||||
box-shadow: inset 0 0 3px 3px yellow;
|
||||
}
|
||||
|
||||
.notation-322f9 {
|
||||
cursor: default;
|
||||
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
|
||||
font-size: 14px;
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
.alpha-d2270 {
|
||||
bottom: 1px;
|
||||
right: 3px;
|
||||
}
|
||||
|
||||
.numeric-fc462 {
|
||||
top: 2px;
|
||||
left: 2px;
|
||||
}
|
2
examples/wchess/wchess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.min.css
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
|
||||
.clearfix-7da63{clear:both}.board-b72b1{border:2px solid #404040;box-sizing:content-box}.square-55d63{float:left;position:relative;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.white-1e1d7{background-color:#f0d9b5;color:#b58863}.black-3c85d{background-color:#b58863;color:#f0d9b5}.highlight1-32417,.highlight2-9c5d2{box-shadow:inset 0 0 3px 3px #ff0}.notation-322f9{cursor:default;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;position:absolute}.alpha-d2270{bottom:1px;right:3px}.numeric-fc462{top:2px;left:2px}
|
After Width: | Height: | Size: 1.4 KiB |
After Width: | Height: | Size: 2.9 KiB |
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 777 B |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 748 B |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 2.8 KiB |
After Width: | Height: | Size: 2.3 KiB |
After Width: | Height: | Size: 1.5 KiB |
After Width: | Height: | Size: 3.7 KiB |
After Width: | Height: | Size: 1.1 KiB |
2
examples/wchess/wchess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.min.js
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
# chessboard.js Change Log
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [1.0.0] - 2019-06-11
|
||||
- Orientation methods now return current orientation. [Issue #64]
|
||||
- Drop support for IE8
|
||||
- Do not check for `window.JSON` (Error #1004)
|
||||
- Rename `ChessBoard` to `Chessboard` (`ChessBoard` is still supported, however)
|
||||
- id query selectors are now supported as the first argument to `Chessboard()`
|
||||
- Remove Error #1002
|
||||
- Format code according to [StandardJS]
|
||||
- Bump minimum jQuery version to 1.8.3
|
||||
- Throttle piece drag functions
|
||||
|
||||
## [0.3.0] - 2013-08-10
|
||||
- Added `appearSpeed` animation config property
|
||||
- Added `onSnapbackEnd` event
|
||||
- Added `onMoveEnd` event
|
||||
|
||||
## [0.2.0] - 2013-08-05
|
||||
- Added `onMouseoverSquare` and `onMouseoutSquare` events
|
||||
- Added `onSnapEnd` event
|
||||
- Added square code as CSS class on the squares
|
||||
- Added [chess.js] integration examples
|
||||
|
||||
## [0.1.0] - 2013-05-21
|
||||
- Initial release
|
||||
|
||||
[chess.js]:https://github.com/jhlywa/chess.js
|
||||
[Issue #64]:https://github.com/oakmac/chessboardjs/issues/64
|
||||
[StandardJS]:https://standardjs.com/
|
@ -0,0 +1,20 @@
|
||||
Copyright 2019 Chris Oakman
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,82 @@
|
||||
# chessboard.js
|
||||
|
||||
chessboard.js is a JavaScript chessboard component. It depends on [jQuery].
|
||||
|
||||
Please see [chessboardjs.com] for documentation and examples.
|
||||
|
||||
## What is chessboard.js?
|
||||
|
||||
chessboard.js is a JavaScript chessboard component with a flexible "just a
|
||||
board" API that
|
||||
|
||||
chessboard.js is a standalone JavaScript Chess Board. It is designed to be "just
|
||||
a board" and expose a powerful API so that it can be used in different ways.
|
||||
Here's a non-exhaustive list of things you can do with chessboard.js:
|
||||
|
||||
- Use chessboard.js to show game positions alongside your expert commentary.
|
||||
- Use chessboard.js to have a tactics website where users have to guess the best
|
||||
move.
|
||||
- Integrate chessboard.js and [chess.js] with a PGN database and allow people to
|
||||
search and playback games (see [Example 5000])
|
||||
- Build a chess server and have users play their games out using the
|
||||
chessboard.js board.
|
||||
|
||||
chessboard.js is flexible enough to handle any of these situations with relative
|
||||
ease.
|
||||
|
||||
## What can chessboard.js **not** do?
|
||||
|
||||
The scope of chessboard.js is limited to "just a board." This is intentional and
|
||||
makes chessboard.js flexible for handling a multitude of chess-related problems.
|
||||
|
||||
This is a common source of confusion for new users. [remove?]
|
||||
|
||||
Specifically, chessboard.js does not understand anything about how the game of
|
||||
chess is played: how a knight moves, who's turn is it, is White in check?, etc.
|
||||
|
||||
Fortunately, the powerful [chess.js] library deals with exactly this sort of
|
||||
problem domain and plays nicely with chessboard.js's flexible API. Some examples
|
||||
of chessboard.js combined with chess.js: 5000, 5001, 5002
|
||||
|
||||
Please see the powerful [chess.js] library for an API to deal with these sorts
|
||||
of questions.
|
||||
|
||||
|
||||
This logic is distinct from the logic of the board. Please see the powerful
|
||||
[chess.js] library for this aspect of your application.
|
||||
|
||||
|
||||
|
||||
Here is a list of things that chessboard.js is **not**:
|
||||
|
||||
- A chess engine
|
||||
- A legal move validator
|
||||
- A PGN parser
|
||||
|
||||
chessboard.js is designed to work well with any of those things, but the idea
|
||||
behind chessboard.js is that the logic that controls the board should be
|
||||
independent of those other problems.
|
||||
|
||||
## Docs and Examples
|
||||
|
||||
- Docs - <http://chessboardjs.com/docs>
|
||||
- Examples - <http://chessboardjs.com/examples>
|
||||
|
||||
## Developer Tools
|
||||
|
||||
```sh
|
||||
# create a build in the build/ directory
|
||||
npm run build
|
||||
|
||||
# re-build the website
|
||||
npm run website
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
[MIT License](LICENSE.md)
|
||||
|
||||
[jQuery]:https://jquery.com/
|
||||
[chessboardjs.com]:http://chessboardjs.com
|
||||
[chess.js]:https://github.com/jhlywa/chess.js
|
||||
[Example 5000]:http://chessboardjs.com/examples#5000
|
@ -0,0 +1,29 @@
|
||||
{
|
||||
"author": "Chris Oakman <chris@oakmac.com> (http://chrisoakman.com/)",
|
||||
"name": "@chrisoakman/chessboardjs",
|
||||
"description": "JavaScript chessboard widget",
|
||||
"homepage": "https://chessboardjs.com",
|
||||
"license": "MIT",
|
||||
"version": "1.0.0",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git://github.com/oakmac/chessboardjs.git"
|
||||
},
|
||||
"files": ["dist/"],
|
||||
"dependencies": {
|
||||
"jquery": ">=3.4.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"csso": "3.5.1",
|
||||
"fs-plus": "3.1.1",
|
||||
"kidif": "1.1.0",
|
||||
"mustache": "2.3.0",
|
||||
"standard": "10.0.2",
|
||||
"uglify-js": "3.6.0"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "standard lib/chessboard.js && node scripts/build.js",
|
||||
"standard": "standard --fix lib/*.js website/js/*.js",
|
||||
"website": "node scripts/website.js"
|
||||
}
|
||||
}
|
499
examples/wchess/wchess.wasm/index-tmpl.html
Normal file
@ -0,0 +1,499 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
.button {
|
||||
background-color: #000000;
|
||||
color: #FFFFFF;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
-moz-border-radius: 10px;
|
||||
-webkit-border-radius: 10px;
|
||||
margin:10px;
|
||||
width: 100px;
|
||||
height: 50px;
|
||||
-webkit-touch-callout: none; /* Safari */
|
||||
-webkit-user-select: none; /* Chrome */
|
||||
-moz-user-select: none; /* Firefox */
|
||||
-ms-user-select: none; /* Internet Explorer/Edge */
|
||||
user-select: none;
|
||||
}
|
||||
button[disabled]{
|
||||
background-color: #cccccc;
|
||||
color: #666666;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
-moz-border-radius: 10px;
|
||||
-webkit-border-radius: 10px;
|
||||
margin:10px;
|
||||
width: 100px;
|
||||
}
|
||||
.center {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
width: 500px;
|
||||
}
|
||||
#description {
|
||||
width: 500px;
|
||||
}
|
||||
</style>
|
||||
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<div id="description">
|
||||
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
This is a demonstration of using Whisper to recognize voice commands in the browser.
|
||||
|
||||
<br><br>
|
||||
|
||||
Usage:<br>
|
||||
|
||||
<ul>
|
||||
<li>Select a Whisper model</li>
|
||||
<li>Accept the microphone permission request if prompted</li>
|
||||
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
|
||||
<li>Release the button and wait for the move to be recognized</li>
|
||||
<li>Repeat</li>
|
||||
</ul>
|
||||
|
||||
Examples:<br>
|
||||
|
||||
<ul>
|
||||
<li><b>"d4"</b></li>
|
||||
<li><b>"e2 e4"</b></li>
|
||||
<li><b>"Knight f3"</b></li>
|
||||
<li><b>"Bishop to b5"</b></li>
|
||||
</ul>
|
||||
|
||||
Features:<br>
|
||||
|
||||
<ul>
|
||||
<li>Model quantization for reduced memory footprint (~42MB)</li>
|
||||
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
|
||||
</ul>
|
||||
|
||||
<b>
|
||||
Note that not all chess moves are supported. For example, castling and pawn promotion
|
||||
currently do not work, but can be easily implemented. There could also be some bugs in
|
||||
the move handling logic in general. The main reason for that is to keep the implementation
|
||||
simple. The assumption is that a real application would already have a proper move
|
||||
validation logic in place.<br><br>
|
||||
|
||||
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
|
||||
its application in the browser for voice recognition locally on your device.
|
||||
</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
<br><br>
|
||||
<button id="clear" onclick="clearCache()">Clear browser cache</button>
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<div id="game">
|
||||
<br>
|
||||
<div id="chessboard" style="width: 500px"></div>
|
||||
<script src="js/jquery-3.7.1.min.js"></script>
|
||||
<script src="js/chessboard-1.0.0.min.js"></script>
|
||||
<script>
|
||||
var board = Chessboard('chessboard', 'start')
|
||||
var move_count = 0;
|
||||
</script>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">select model</span></b>
|
||||
|
||||
<div id="input" class="center">
|
||||
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
|
||||
</div>
|
||||
|
||||
<pre id="state-grammar">[The grammar will be displayed here]</pre>
|
||||
|
||||
<pre id="state-moves">[The moves will be displayed here]</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="js/helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// web audio context
|
||||
var context = null;
|
||||
|
||||
// the command instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
var model_file = null;
|
||||
|
||||
var module_ready = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Module initialized successfully!');
|
||||
module_ready = true;
|
||||
initInstance();
|
||||
}
|
||||
};
|
||||
|
||||
function initInstance() {
|
||||
if (!module_ready || !model_file || instance) return
|
||||
|
||||
instance = Module.init(model_file);
|
||||
|
||||
if (instance) {
|
||||
setStatus('Ready');
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
else {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
}
|
||||
}
|
||||
|
||||
function setStatus(text) {
|
||||
document.getElementById('state-status').innerHTML = text;
|
||||
}
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
model_file = fname;
|
||||
initInstance();
|
||||
}
|
||||
|
||||
function loadWhisper() {
|
||||
setStatus('Loading')
|
||||
//let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
|
||||
let url = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin';
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = 42;
|
||||
|
||||
model_whisper = 'tiny.en-q8_0';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
|
||||
// init audio capture so that the user receives a permission request
|
||||
{
|
||||
let context = new AudioContext({
|
||||
sampleRate: 16000,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
context.close();
|
||||
}
|
||||
|
||||
document.getElementById('toggler').style.display = 'block';
|
||||
}
|
||||
|
||||
//
|
||||
// microphone
|
||||
//
|
||||
|
||||
const kSampleRate = 16000;
|
||||
const kRestartRecording_s = 120;
|
||||
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
|
||||
|
||||
var mediaRecorder = null;
|
||||
var doRecording = false;
|
||||
var startTime = 0;
|
||||
|
||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
||||
|
||||
function stopRecording() {
|
||||
if (mediaRecorder) {
|
||||
mediaRecorder.stop();
|
||||
}
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
if (!context) {
|
||||
context = new AudioContext({
|
||||
sampleRate: kSampleRate,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
}
|
||||
|
||||
startTime = Date.now();
|
||||
|
||||
var chunks = [];
|
||||
var stream = null;
|
||||
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
mediaRecorder.ondataavailable = function(e) {
|
||||
chunks.push(e.data);
|
||||
|
||||
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
|
||||
var reader = new FileReader();
|
||||
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
context.decodeAudioData(buf.buffer, function(audioBuffer) {
|
||||
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
|
||||
var source = offlineContext.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
source.connect(offlineContext.destination);
|
||||
source.start(0);
|
||||
|
||||
offlineContext.startRendering().then(function(renderedBuffer) {
|
||||
let audio = renderedBuffer.getChannelData(0);
|
||||
printTextarea('js: number of samples: ' + audio.length);
|
||||
Module.set_audio(instance, audio);
|
||||
});
|
||||
|
||||
mediaRecorder = null;
|
||||
context = null;
|
||||
});
|
||||
}
|
||||
|
||||
reader.readAsArrayBuffer(blob);
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = function(e) {
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
};
|
||||
|
||||
mediaRecorder.start();
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
var nLines = 0;
|
||||
var movesAll = '';
|
||||
|
||||
// document.body.addEventListener('keydown', function(event) {
|
||||
// if (event.keyCode === 32) {
|
||||
// document.getElementById('toggler').innerText = "";
|
||||
// onStart();
|
||||
// }
|
||||
// }, true);
|
||||
|
||||
// document.body.addEventListener('keyup', function(event) {
|
||||
// if (event.keyCode === 32) {
|
||||
// document.getElementById('toggler').innerText = "Hold";
|
||||
// onStop();
|
||||
// }
|
||||
// }, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener("touchstart", function(event){
|
||||
this.innerText = "";
|
||||
onStart();
|
||||
}, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener("touchend", function(event){
|
||||
this.innerText = "Hold";
|
||||
onStop();
|
||||
}, true)
|
||||
|
||||
document.getElementById('toggler').addEventListener('mousedown', function(event) {
|
||||
this.innerText = "";
|
||||
onStart();
|
||||
}, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener('mouseup', function(event) {
|
||||
this.innerText = "Hold";
|
||||
onStop();
|
||||
}, true);
|
||||
|
||||
function onStart() {
|
||||
if (!instance) return;
|
||||
setStatus('Listening');
|
||||
|
||||
startRecording();
|
||||
}
|
||||
|
||||
function onStop() {
|
||||
setStatus('Processing');
|
||||
printTextarea('js: stopping recording ...');
|
||||
stopRecording();
|
||||
}
|
||||
|
||||
function setMove(move, prob) {
|
||||
if (move != null && move.length > 1) {
|
||||
let gameOver = move[move.length - 1] === '#';
|
||||
if (gameOver) {
|
||||
move = move.substring(0, move.length - 1);
|
||||
document.getElementById('toggler').disabled = true;
|
||||
}
|
||||
board.move(move);
|
||||
|
||||
movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>';
|
||||
nLines++;
|
||||
|
||||
// if more than 10 lines, remove the first line
|
||||
if (nLines > 10) {
|
||||
var i = movesAll.indexOf('<br>');
|
||||
if (i > 0) {
|
||||
movesAll = movesAll.substring(i + 4);
|
||||
nLines--;
|
||||
}
|
||||
}
|
||||
++move_count;
|
||||
setStatus(gameOver ? 'Done' : move_count % 2 ? 'Black\'s turn' : 'White\'s turn');
|
||||
document.getElementById('state-moves').innerHTML = movesAll;
|
||||
}
|
||||
else {
|
||||
setStatus('Failed. ' + (move_count % 2 ? 'Black\'s turn' : 'White\'s turn'));
|
||||
}
|
||||
}
|
||||
|
||||
function setGrammar(grammar) {
|
||||
document.getElementById('state-grammar').innerHTML = grammar;
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="js/chess.js"></script>
|
||||
</body>
|
||||
</html>
|
2
examples/wchess/wchess.wasm/jquery-3.7.1.min.js
vendored
Normal file
141
examples/wchess/wchess.wasm/wchess.wasm.cpp
Normal file
@ -0,0 +1,141 @@
|
||||
#include <WChess.h>
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <thread>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::mutex g_mutex;
|
||||
std::thread g_worker;
|
||||
|
||||
std::condition_variable g_cv;
|
||||
|
||||
bool g_running(false);
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
void set_move(const std::string & move, float prob) {
|
||||
MAIN_THREAD_EM_ASM({
|
||||
setMove(UTF8ToString($0), $1)
|
||||
}, move.c_str(), prob);
|
||||
}
|
||||
|
||||
void set_grammar(const std::string & grammar) {
|
||||
MAIN_THREAD_EM_ASM({
|
||||
setGrammar(UTF8ToString($0))
|
||||
}, grammar.c_str());
|
||||
}
|
||||
|
||||
bool get_audio(std::vector<float> & audio) {
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
|
||||
if (!g_running) return false;
|
||||
audio = std::move(g_pcmf32);
|
||||
return true;
|
||||
}
|
||||
|
||||
void wchess_main(size_t i) {
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
wparams.no_timestamps = true;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 1280; // partial encoder context for better performance
|
||||
|
||||
wparams.temperature = 0.0f;
|
||||
wparams.temperature_inc = 2.0f;
|
||||
wparams.greedy.best_of = 1;
|
||||
|
||||
wparams.beam_search.beam_size = 1;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
wparams.grammar_penalty = 100.0;
|
||||
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
|
||||
|
||||
printf("command: using %d threads\n", wparams.n_threads);
|
||||
|
||||
WChess::callbacks cb;
|
||||
cb.get_audio = get_audio;
|
||||
cb.set_move = set_move;
|
||||
cb.set_grammar = set_grammar;
|
||||
|
||||
WChess(g_contexts[i], wparams, cb, {}).run();
|
||||
|
||||
if (i < g_contexts.size()) {
|
||||
whisper_free(g_contexts[i]);
|
||||
g_contexts[i] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(command) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
wchess_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
g_running = false;
|
||||
}
|
||||
g_cv.notify_one();
|
||||
}));
|
||||
|
||||
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
|
||||
--index;
|
||||
|
||||
if (index >= g_contexts.size()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (g_contexts[index] == nullptr) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
const int n = audio["length"].as<int>();
|
||||
|
||||
emscripten::val heap = emscripten::val::module_property("HEAPU8");
|
||||
emscripten::val memory = heap["buffer"];
|
||||
|
||||
g_pcmf32.resize(n);
|
||||
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
}
|
||||
g_cv.notify_one();
|
||||
|
||||
return 0;
|
||||
}));
|
||||
}
|
15
examples/whisper.android.java/.gitignore
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
20
examples/whisper.android.java/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
A sample Android app using java code and [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Modify the modelFilePath in the WhisperService.java
|
||||
6. Modify the sampleFilePath in the WhisperService.java
|
||||
7. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
PS:
|
||||
1. Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.
|
||||
2. The cpp code is compiled during the build process
|
||||
3. If you want to import a compiled cpp project in your Android project, please refer to the https://github.com/litongjava/whisper.cpp.android.java.demo
|
||||
|
||||

|
||||
|
BIN
examples/whisper.android.java/README_files/1.jpg
Normal file
After Width: | Height: | Size: 67 KiB |
1
examples/whisper.android.java/app/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|