Compare commits

..

3 Commits

396 changed files with 56483 additions and 71215 deletions

View File

@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build
ARG CUDA_DOCKER_ARCH=all
RUN apt-get update && \
apt-get install -y build-essential git cmake libsdl2-dev wget
apt-get install -y build-essential git cmake libsdl2-dev
WORKDIR /app
@ -23,6 +23,6 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV GGML_CUDA=1
RUN make base.en
RUN make
ENTRYPOINT ["/app/main"]

View File

@ -17,7 +17,7 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
ENV GGML_CUDA=1
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake \
apt-get install -y build-essential libsdl2-dev \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# Ref: https://stackoverflow.com/a/53464012
@ -25,7 +25,7 @@ ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY .. .
RUN make base.en
RUN make
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
ENV CUDA_MAIN_VERSION=12.3
@ -33,7 +33,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg wget cmake \
apt-get install -y curl ffmpeg \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app

View File

@ -2,17 +2,17 @@ FROM ubuntu:22.04 AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential wget cmake \
apt-get install -y build-essential \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
RUN make base.en
RUN make
FROM ubuntu:22.04 AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg libsdl2-dev wget cmake \
apt-get install -y curl ffmpeg libsdl2-dev \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app

View File

@ -3,41 +3,61 @@ on:
push:
paths:
- bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- src/whisper.cpp
- include/whisper.h
- ggml/src/ggml.c
- ggml/src/ggml-impl.h
- ggml/src/ggml-aarch64.h
- ggml/src/ggml-aarch64.c
- ggml/src/ggml-alloc.c
- ggml/src/ggml-backend-impl.h
- ggml/src/ggml-backend.cpp
- ggml/src/ggml-common.h
- ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h
- ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h
- ggml/include/ggml-cuda.h
- ggml/include/ggml-kompute.h
- ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- scripts/get-flags.mk
- examples/dr_wav.h
pull_request:
paths:
- bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- src/whisper.cpp
- include/whisper.h
- ggml/src/ggml.c
- ggml/src/ggml-impl.h
- ggml/src/ggml-aarch64.h
- ggml/src/ggml-aarch64.c
- ggml/src/ggml-alloc.c
- ggml/src/ggml-backend-impl.h
- ggml/src/ggml-backend.cpp
- ggml/src/ggml-common.h
- ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h
- ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h
- ggml/include/ggml-cuda.h
- ggml/include/ggml-kompute.h
- ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- scripts/get-flags.mk
- examples/dr_wav.h
@ -50,6 +70,6 @@ jobs:
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.1'
ruby-version: '3.0'
- uses: actions/checkout@v4
- run: rake test

View File

@ -1,19 +1,8 @@
name: CI
on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
on: [push, pull_request]
env:
ubuntu_image: "ubuntu:22.04"
VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite"
jobs:
ubuntu-latest:
@ -22,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [linux/amd64, linux/ppc64le]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
@ -38,61 +27,9 @@ jobs:
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev cmake
cmake -B build
cmake --build build --config Release -j $(nproc)'
ubuntu-latest-arm64:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
arch: [linux/arm64]
steps:
- name: Clone
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }}
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev cmake
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
cmake --build build --config Release -j $(nproc)'
ubuntu-latest-arm-v7:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
arch: [linux/arm/v7]
steps:
- name: Clone
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }}
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev cmake
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
cmake --build build --config Release -j $(nproc)'
apt install -y build-essential libsdl2-dev
make
make stream'
macOS-latest:
runs-on: macOS-latest
@ -104,30 +41,30 @@ jobs:
- name: Dependencies
run: |
brew update
brew install sdl2 cmake
brew install sdl2
- name: Build
run: |
cmake -B build
cmake --build build --config Release
make
make stream
# freeBSD-latest:
# runs-on: macos-12
#
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: Build
# uses: cross-platform-actions/action@v0.24.0
# with:
# operating_system: freebsd
# version: '13.3'
# run: |
# sudo pkg update
# sudo pkg install -y gmake sdl2 cmake
# cmake -B build
# cmake --build build --config Release
freeBSD-latest:
runs-on: macos-12
steps:
- name: Clone
uses: actions/checkout@v4
- name: Build
uses: cross-platform-actions/action@v0.24.0
with:
operating_system: freebsd
version: '13.3'
run: |
sudo pkg update
sudo pkg install -y gmake sdl2
gmake
gmake stream
ubuntu-latest-gcc:
runs-on: ubuntu-latest
@ -136,7 +73,7 @@ jobs:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/amd64, linux/ppc64le]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
@ -157,62 +94,6 @@ jobs:
make
ctest -L gh --output-on-failure'
ubuntu-latest-gcc-arm64:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/arm64]
steps:
- name: Clone
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }}
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
make
ctest -L gh --output-on-failure'
ubuntu-latest-gcc-arm-v7:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/arm/v7]
steps:
- name: Clone
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }}
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
make
ctest -L gh --output-on-failure'
ubuntu-latest-clang:
runs-on: ubuntu-latest
@ -398,10 +279,25 @@ jobs:
mingw-w64-${{matrix.env}}-SDL2
mingw-w64-${{matrix.env}}-openblas
- name: Build using make
shell: msys2 {0}
run: |
make -j $(nproc)
- name: Clean after building using make
shell: msys2 {0}
run: |
make clean
- name: Build using make w/ OpenBLAS
shell: msys2 {0}
run: |
make GGML_OPENBLAS=1 -j $(nproc)
- name: Build using CMake
shell: msys2 {0}
run: |
cmake -B build -DWHISPER_SDL2=ON
cmake -B build
cmake --build build --config ${{ matrix.build }} -j $(nproc)
- name: Clean after building using CMake
@ -412,7 +308,7 @@ jobs:
- name: Build using CMake w/ OpenBLAS
shell: msys2 {0}
run: |
cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
cmake -B build -DGGML_OPENBLAS=ON
cmake --build build --config ${{ matrix.build }} -j $(nproc)
windows:
@ -486,8 +382,10 @@ jobs:
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
s2arc: x86
- arch: x64
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
s2arc: x64
- sdl2: ON
s2ver: 2.28.5
@ -496,21 +394,17 @@ jobs:
- name: Clone
uses: actions/checkout@v4
- name: Export GitHub Actions cache environment variables
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2
- name: Install OpenBLAS and pkgconfiglite
- name: Fetch OpenBLAS
if: matrix.blas == 'ON'
run: |
vcpkg install --triplet=${{ matrix.s2arc }}-windows openblas
choco install pkgconfiglite
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
7z x blas.zip -oblas -y
copy blas/include/cblas.h .
copy blas/include/openblas_config.h .
echo "OPENBLAS_PATH=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
@ -522,10 +416,9 @@ jobs:
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake"
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DGGML_BLAS=${{ matrix.blas }}
-DGGML_BLAS_VENDOR=OpenBLAS
-DGGML_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build
@ -533,9 +426,9 @@ jobs:
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy openblas.dll
- name: Copy libopenblas.dll
if: matrix.blas == 'ON'
run: copy "C:/vcpkg/packages/openblas_${{ matrix.s2arc }}-windows/bin/openblas.dll" build/bin/${{ matrix.build }}
run: copy "$env:OPENBLAS_PATH/bin/libopenblas.dll" build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
@ -550,6 +443,7 @@ jobs:
windows-cublas:
runs-on: windows-2019
strategy:
matrix:
build: [Release]
@ -559,10 +453,12 @@ jobs:
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
sdl2: ON
sdl2_ver: 2.28.5
s2arc: x64
- sdl2: ON
s2ver: 2.28.5
steps:
- name: Clone repository
- name: Clone
uses: actions/checkout@v4
- name: Add msbuild to PATH
@ -574,43 +470,38 @@ jobs:
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Install 7-Zip
run: choco install 7zip -y
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt
- name: Configure CMake
shell: cmd
run: |
cmake -S . -B ./build -A ${{ matrix.arch }} ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DCMAKE_CUDA_ARCHITECTURES=all ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%"
- name: Build Project
shell: cmd
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DGGML_CUDA=${{ matrix.cublas }}
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build ${{ matrix.cuda-toolkit }}
run: |
cd ./build
cmake --build . --config ${{ matrix.build }}
cmake --build . --config ${{ matrix.build }}
- name: Copy CUDA DLLs
run: |
Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" |
Copy-Item -Destination "build/bin/${{ matrix.build }}"
run: >
Copy-Item -PassThru
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
-Include cudart64_*,cublas64_*,cublasLt64_*
-Destination build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }}
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v4
with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
@ -638,7 +529,7 @@ jobs:
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ios-xcode-build:
ios:
runs-on: macos-latest
strategy:
@ -646,7 +537,7 @@ jobs:
build: [Release]
steps:
- name: Checkout code
- name: Clone
uses: actions/checkout@v4
- name: Configure
@ -654,34 +545,11 @@ jobs:
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DWHISPER_BUILD_EXAMPLES=OFF \
-DWHISPER_BUILD_TESTS=OFF \
-DWHISPER_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
sudo cmake --install . --config Release
- name: xcodebuild for swift package
id: xcodebuild
run: |
xcodebuild -scheme whisper-Package -destination 'generic/platform=iOS'
#- name: Build objc example
# run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos build
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
- name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
android:
runs-on: ubuntu-latest
@ -692,6 +560,12 @@ jobs:
with:
path: whisper
- name: Clone
uses: actions/checkout@v4
with:
repository: ggerganov/ggml
path: ggml
- name: Install Java
uses: actions/setup-java@v4
with:
@ -710,7 +584,7 @@ jobs:
run: |
export PATH_TO_GGML=$PWD/ggml
cd whisper/examples/whisper.android
./gradlew assembleRelease --no-daemon
./gradlew assembleRelease --no-daemon -PGGML_HOME=$PATH_TO_GGML
# TODO: disable because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/11019444420/job/30627193602
# android_java:
@ -792,6 +666,5 @@ jobs:
- name: Test quantize
run: |
./models/download-ggml-model.sh tiny.en
cmake -B build
cmake --build build --config Release
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
make quantize
./quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0

View File

@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
config:
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64,linux/arm64" }
#TODO: the cuda image keeps failing - disable for now
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
@ -45,7 +45,7 @@ jobs:
with:
context: .
push: true
platforms: ${{ matrix.config.platform }}
platforms: ${{ matrix.config.platforms }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
file: ${{ matrix.config.dockerfile }}
@ -54,6 +54,6 @@ jobs:
with:
context: .
push: ${{ github.event_name == 'push' }}
platforms: ${{ matrix.config.platform }}
platforms: ${{ matrix.config.platforms }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
file: ${{ matrix.config.dockerfile }}

4
.gitignore vendored
View File

@ -1,6 +1,5 @@
*.o
*.a
*.d
.cache/
.coreml/
.test/
@ -20,9 +19,6 @@ build-*/
.swiftpm
*.metallib
ggml-metal-embed.metal
ggml-metal-embed.metal.tmp
/main
/stream
/command

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.7.4)
project("whisper.cpp" VERSION 1.7.1)
include(CheckIncludeFileCXX)
set(SOVERSION 1)

1131
Makefile

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,47 @@ let package = Package(
.library(name: "whisper", targets: ["whisper"]),
],
targets: [
.systemLibrary(name: "whisper", pkgConfig: "whisper"),
]
.target(
name: "whisper",
path: ".",
exclude: [
"bindings",
"cmake",
"coreml",
"examples",
"extra",
"models",
"samples",
"tests",
"CMakeLists.txt",
"Makefile"
],
sources: [
"ggml/src/ggml.c",
"src/whisper.cpp",
"ggml/src/ggml-aarch64.c",
"ggml/src/ggml-alloc.c",
"ggml/src/ggml-backend.cpp",
"ggml/src/ggml-quants.c",
"ggml/src/ggml-metal.m"
],
resources: [.process("ggml-metal.metal")],
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE"),
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
],
linkerSettings: [
.linkedFramework("Accelerate")
]
)
],
cxxLanguageStandard: .cxx11
)

347
README.md
View File

@ -7,7 +7,7 @@
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Stable: [v1.7.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -16,7 +16,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- [Integer quantization support](#quantization)
- [4-bit and 5-bit integer quantization support](#quantization)
- Zero memory allocations at runtime
- [Vulkan support](#vulkan-gpu-support)
- Support for CPU-only inference
@ -53,6 +53,18 @@ On Apple Silicon, the inference runs fully on the GPU via Metal:
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml/include/ggml.h) / [ggml.c](ggml/src/ggml.c))
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](include/whisper.h) / [whisper.cpp](src/whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Quick start
First clone the repository:
@ -73,26 +85,134 @@ Then, download one of the Whisper [models](models/README.md) converted in [`ggml
sh ./models/download-ggml-model.sh base.en
```
Now build the [whisper-cli](examples/cli) example and transcribe an audio file like this:
Now build the [main](examples/main) example and transcribe an audio file like this:
```bash
# build the project
cmake -B build
cmake --build build --config Release
# build the main example
make -j
# transcribe an audio file
./build/bin/whisper-cli -f samples/jfk.wav
./main -f samples/jfk.wav
```
---
For a quick demo, simply run `make base.en`.
For a quick demo, simply run `make base.en`:
```text
$ make -j base.en
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
./main -h
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
sh ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
You can now use it like this:
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
===============================================
Running base.en on all samples in ./samples ...
===============================================
----------------------------------------------
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
----------------------------------------------
whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512
whisper_model_load: n_audio_head = 8
whisper_model_load: n_audio_layer = 6
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 512
whisper_model_load: n_text_head = 8
whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
whisper_model_load: kv self size = 5.25 MB
whisper_model_load: kv cross size = 17.58 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 140.60 MB
whisper_model_load: model size = 140.54 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: load time = 113.81 ms
whisper_print_timings: mel time = 15.40 ms
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
whisper_print_timings: total time = 476.31 ms
```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
For detailed usage instructions, run: `./build/bin/whisper-cli -h`
For detailed usage instructions, run: `./main -h`
Note that the [whisper-cli](examples/cli) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
Note that the [main](examples/main) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
For example, you can use `ffmpeg` like this:
```bash
@ -145,12 +265,11 @@ Here are the steps for creating and using a quantized model:
```bash
# quantize a model with Q5_0 method
cmake -B build
cmake --build build --config Release
./build/bin/quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
make -j quantize
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
# run the examples as usual, specifying the quantized model file
./build/bin/whisper-cli -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
```
## Core ML support
@ -184,6 +303,10 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
- Build `whisper.cpp` with Core ML support:
```bash
# using Makefile
make clean
WHISPER_COREML=1 make -j
# using CMake
cmake -B build -DWHISPER_COREML=1
cmake --build build -j --config Release
@ -192,7 +315,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
- Run the examples as usual. For example:
```text
$ ./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
@ -276,7 +399,7 @@ This can result in significant speedup in encoder performance. Here are the inst
- Run the examples as usual. For example:
```text
$ ./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
@ -293,7 +416,7 @@ This can result in significant speedup in encoder performance. Here are the inst
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
cached for the next run.
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support
@ -303,8 +426,8 @@ First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-do
Now build `whisper.cpp` with CUDA support:
```
cmake -B build -DGGML_CUDA=1
cmake --build build -j --config Release
make clean
GGML_CUDA=1 make -j
```
## Vulkan GPU support
@ -313,8 +436,8 @@ First, make sure your graphics card driver provides support for Vulkan API.
Now build `whisper.cpp` with Vulkan support:
```
cmake -B build -DGGML_VULKAN=1
cmake --build build -j --config Release
make clean
make GGML_VULKAN=1 -j
```
## BLAS CPU support via OpenBLAS
@ -325,13 +448,28 @@ First, make sure you have installed `openblas`: https://www.openblas.net/
Now build `whisper.cpp` with OpenBLAS support:
```
cmake -B build -DGGML_BLAS=1
cmake --build build -j --config Release
make clean
GGML_OPENBLAS=1 make -j
```
## BLAS CPU support via Intel MKL
Encoder processing can be accelerated on the CPU via the BLAS compatible interface of Intel's Math Kernel Library.
First, make sure you have installed Intel's MKL runtime and development packages: https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-download.html
Now build `whisper.cpp` with Intel MKL BLAS support:
```
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake -DWHISPER_MKL=ON ..
WHISPER_MKL=1 make -j
```
## Ascend NPU support
Ascend NPU provides inference acceleration via [`CANN`](https://www.hiascend.com/en/software/cann) and AI cores.
Ascend NPU provides inference acceleration via [`CANN`](https://www.hiascend.com/en/software/cann) and AI cores.
First, check if your Ascend NPU device is supported:
@ -345,14 +483,16 @@ Then, make sure you have installed [`CANN toolkit`](https://www.hiascend.com/en/
Now build `whisper.cpp` with CANN support:
```
cmake -B build -DGGML_CANN=1
cmake --build build -j --config Release
mkdir build
cd build
cmake .. -D GGML_CANN=on
make -j
```
Run the inference examples as usual, for example:
```
./build/bin/whisper-cli -f samples/jfk.wav -m models/ggml-base.en.bin -t 8
./build/bin/main -f samples/jfk.wav -m models/ggml-base.en.bin -t 8
```
*Notes:*
@ -360,6 +500,38 @@ Run the inference examples as usual, for example:
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
## Docker
### Prerequisites
- Docker must be installed and running on your system.
- Create a folder to store big models & intermediate files (ex. /whisper/models)
### Images
We have two Docker images available for this project:
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
### Usage
```shell
# download model and persist it in a local folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./models/download-ggml-model.sh base /models"
# transcribe an audio file
docker run -it --rm \
-v path/to/models:/models \
-v path/to/audios:/audios \
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
# transcribe an audio file in samples folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
```
## Installing with Conan
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
@ -374,6 +546,89 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
- Inference only
## Another example
Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
in about half a minute on a MacBook M1 Pro, using `medium.en` model:
<details>
<summary>Expand to see the result</summary>
```text
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 1024
whisper_model_load: n_audio_head = 16
whisper_model_load: n_audio_layer = 24
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 1024
whisper_model_load: n_text_head = 16
whisper_model_load: n_text_layer = 24
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 4
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
whisper_model_load: kv self size = 42.00 MB
whisper_model_load: kv cross size = 140.62 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 1462.35 MB
whisper_model_load: model size = 1462.12 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
[00:03:13.000 --> 00:03:19.000] [Silence]
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 569.03 ms
whisper_print_timings: mel time = 146.85 ms
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
whisper_print_timings: total time = 32733.52 ms
```
</details>
## Real-time audio input example
This is a naive example of performing real-time inference on audio from your microphone.
@ -381,9 +636,8 @@ The [stream](examples/stream) tool samples the audio every half a second and run
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```bash
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
make stream -j
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
@ -394,7 +648,7 @@ Adding the `--print-colors` argument will print the transcribed text using an ex
to highlight words with high or low confidence:
```bash
./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
```
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
@ -404,7 +658,7 @@ to highlight words with high or low confidence:
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
```text
$ ./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
@ -428,7 +682,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
```text
$ ./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
@ -475,7 +729,7 @@ Sample usage:
./models/download-ggml-model.sh small.en-tdrz
# run as usual, adding the "-tdrz" command-line argument
./build/bin/whisper-cli -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
./main -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
...
main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, tdrz = 1, timestamps = 1 ...
...
@ -492,14 +746,14 @@ main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 pr
## Karaoke-style movie generation (experimental)
The [whisper-cli](examples/cli) example provides support for output of karaoke-style movies, where the
The [main](examples/main) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed.
Here are a few _"typical"_ examples:
```bash
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
source ./samples/jfk.wav.wts
ffplay ./samples/jfk.wav.mp4
```
@ -509,7 +763,7 @@ https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b
---
```bash
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
source ./samples/mm0.wav.wts
ffplay ./samples/mm0.wav.mp4
```
@ -519,7 +773,7 @@ https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-9
---
```bash
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
source ./samples/gb0.wav.wts
ffplay ./samples/gb0.wav.mp4
```
@ -544,7 +798,7 @@ https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8
## Benchmarks
In order to have an objective comparison of the performance of the inference across different system configurations,
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
use the [bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
took to execute it. The results are summarized in the following Github issue:
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
@ -607,12 +861,13 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| Example | Web | Description |
| --------------------------------------------------- | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
| [whisper-cli](examples/cli) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
| [whisper-bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [whisper-stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [whisper-command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [whisper-server](examples/server) | | HTTP transcription server with OAI-like API |
| [whisper-talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
@ -620,7 +875,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [server](examples/server) | | HTTP transcription server with OAI-like API |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)

View File

@ -1,5 +0,0 @@
module whisper [system] {
header "whisper.h"
link "whisper"
export *
}

View File

@ -1,4 +0,0 @@
#pragma once
#include <whisper.h>

View File

@ -67,5 +67,5 @@ copy /y ..\..\build\bin\Release\whisper.dll build\generated\resources\main\win32
## License
The license for the Java bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
}
/** Flag to suppress non-speech tokens. */
public CBool suppress_nst;
public CBool suppress_non_speech_tokens;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.7.4",
"version": "1.7.1",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

View File

@ -22,7 +22,7 @@ Usage
```ruby
require "whisper"
whisper = Whisper::Context.new("base")
whisper = Whisper::Context.new("path/to/model.bin")
params = Whisper::Params.new
params.language = "en"
@ -31,7 +31,7 @@ params.duration = 60_000
params.max_text_tokens = 300
params.translate = true
params.print_timestamps = false
params.initial_prompt = "Initial prompt here."
params.prompt = "Initial prompt here."
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
puts whole_text
@ -41,67 +41,21 @@ end
### Preparing model ###
Some models are prepared up-front:
Use script to download model file(s):
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
sh ./models/download-ggml-model.sh base.en
```
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You also can use shorthand for pre-converted models:
```ruby
whisper = Whisper::Context.new("base.en")
```
You can see the list of prepared model names by `Whisper::Model.pre_converted_models.keys`:
```ruby
puts Whisper::Model.pre_converted_models.keys
# tiny
# tiny.en
# tiny-q5_1
# tiny.en-q5_1
# tiny-q8_0
# base
# base.en
# base-q5_1
# base.en-q5_1
# base-q8_0
# :
# :
```
You can also use local model files you prepared:
```ruby
whisper = Whisper::Context.new("path/to/your/model.bin")
```
Or, you can download model files:
```ruby
whisper = Whisper::Context.new("https://example.net/uri/of/your/model.bin")
# Or
whisper = Whisper::Context.new(URI("https://example.net/uri/of/your/model.bin"))
```
See [models][] page for details.
There are some types of models. See [models][] page for details.
### Preparing audio file ###
Currently, whisper.cpp accepts only 16-bit WAV files.
API
---
### Segments ###
### API ###
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
@ -131,6 +85,13 @@ end
You can also add hook to params called on new segment:
```ruby
def format_time(time_ms)
sec, decimal_part = time_ms.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
# Add hook before calling #transcribe
params.on_new_segment do |segment|
line = "[%{st} --> %{ed}] %{text}" % {
@ -146,98 +107,5 @@ whisper.transcribe("path/to/audio.wav", params)
```
### Models ###
You can see model information:
```ruby
whisper = Whisper::Context.new("base")
model = whisper.model
model.n_vocab # => 51864
model.n_audio_ctx # => 1500
model.n_audio_state # => 512
model.n_audio_head # => 8
model.n_audio_layer # => 6
model.n_text_ctx # => 448
model.n_text_state # => 512
model.n_text_head # => 8
model.n_text_layer # => 6
model.n_mels # => 80
model.ftype # => 1
model.type # => "base"
```
### Logging ###
You can set log callback:
```ruby
prefix = "[MyApp] "
log_callback = ->(level, buffer, user_data) {
case level
when Whisper::LOG_LEVEL_NONE
puts "#{user_data}none: #{buffer}"
when Whisper::LOG_LEVEL_INFO
puts "#{user_data}info: #{buffer}"
when Whisper::LOG_LEVEL_WARN
puts "#{user_data}warn: #{buffer}"
when Whisper::LOG_LEVEL_ERROR
puts "#{user_data}error: #{buffer}"
when Whisper::LOG_LEVEL_DEBUG
puts "#{user_data}debug: #{buffer}"
when Whisper::LOG_LEVEL_CONT
puts "#{user_data}same to previous: #{buffer}"
end
}
Whisper.log_set log_callback, prefix
```
Using this feature, you are also able to suppress log:
```ruby
Whisper.log_set ->(level, buffer, user_data) {
# do nothing
}, nil
Whisper::Context.new("base")
```
### Low-level API to transcribe ###
You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility.
```ruby
require "whisper"
require "wavefile"
reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
whisper = Whisper::Context.new("base")
whisper.full(Whisper::Params.new, samples)
whisper.each_segment do |segment|
puts segment.text
end
```
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
Development
-----------
% git clone https://github.com/ggerganov/whisper.cpp.git
% cd whisper.cpp/bindings/ruby
% rake test
First call of `rake test` builds an extension and downloads a model for testing. After that, you add tests in `tests` directory and modify `ext/ruby_whisper.cpp`.
If something seems wrong on build, running `rake clean` solves some cases.
License
-------
The same to [whisper.cpp][].
[whisper.cpp]: https://github.com/ggerganov/whisper.cpp
[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models

View File

@ -1,64 +1,59 @@
require 'rake/clean'
require "bundler/gem_tasks"
require "pathname"
require "yaml"
require "rake/testtask"
require_relative "extsources"
extsources = YAML.load_file("extsources.yaml")
SOURCES = FileList[]
EXTSOURCES.each do |src|
extsources.each do |src|
basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
dir = dest.pathmap("%d")
dest = basename == "LICENSE" ? basename : basename.pathmap("ext/%f")
file src
directory dir
file dest => [src, dir] do |t|
file dest => src do |t|
cp t.source, t.name
end
SOURCES.include dest
end
CLEAN.include SOURCES
CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
CLEAN.include FileList[
"ext/*.o",
"ext/*.metal",
"ext/whisper.{so,bundle,dll}",
"ext/depend"
]
task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
task build: SOURCES + FileList[
"ext/extconf.rb",
"ext/ruby_whisper.h",
"ext/ruby_whisper.cpp",
"whispercpp.gemspec",
]
directory "pkg"
CLOBBER.include "pkg"
TEST_MODEL = "../../models/ggml-base.en.bin"
LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
SO_FILE = File.join("ext", LIB_NAME)
LIB_FILE = File.join("lib", LIB_NAME)
file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
Dir.chdir "ext" do
ruby "extconf.rb"
end
end
file SO_FILE => "ext/Makefile" do |t|
directory "lib"
task LIB_FILE => SOURCES + ["lib"] do |t|
Dir.chdir "ext" do
sh "ruby extconf.rb"
sh "make"
end
end
CLEAN.include SO_FILE
directory "lib"
file LIB_FILE => [SO_FILE, "lib"] do |t|
copy t.source, t.name
mv "ext/#{LIB_NAME}", t.name
end
CLEAN.include LIB_FILE
Rake::TestTask.new do |t|
t.test_files = FileList["tests/test_*.rb"]
end
task test: [TEST_MODEL, LIB_FILE]
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
Dir.chdir "tests/jfk_reader" do
ruby "extconf.rb"
sh "make"
file TEST_MODEL do
Dir.chdir "../.." do
sh "./models/download-ggml-model.sh base.en"
end
end
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
task test: [LIB_FILE, TEST_MEMORY_VIEW]

View File

@ -1,13 +1,35 @@
Makefile
whisper.so
ggml.c
ggml.h
ggml-alloc.c
ggml-alloc.h
ggml-aarch64.c
ggml-aarch64.h
ggml-backend.cpp
ggml-backend-impl.h
ggml-backend.c
ggml-backend.h
ggml-common.h
ggml-cpu-impl.h
ggml-metal.m
ggml-metal.metal
ggml-metal-embed.metal
ggml-blas.cpp
ggml-cuda.h
ggml-impl.h
ggml-kompute.h
ggml-metal.h
ggml-opencl.h
ggml-quants.c
ggml-quants.h
ggml-sycl.h
ggml-vulkan.h
ggml-blas.h
get-flags.mk
whisper.cpp
whisper.h
dr_wav.h
depend
whisper.bundle
whisper.so
whisper.dll
scripts/get-flags.mk
*.o
*.c
*.cpp
*.h
*.m
*.metal
!ruby_whisper.cpp
!ruby_whisper.h

View File

@ -1,9 +0,0 @@
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1,10 +1,7 @@
require 'mkmf'
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++17'
$LDFLAGS << ' -lstdc++'
$CXXFLAGS << ' -std=c++11'
# Set to true when building binary gems
if enable_config('static-stdlib', false)
$LDFLAGS << ' -static-libgcc -static-libstdc++'
@ -15,6 +12,34 @@ if enable_config('march-tune-native', false)
$CXXFLAGS << ' -march=native -mtune=native'
end
def with_disabling_unsupported_files
disabled_files = []
unless $GGML_METAL
disabled_files << 'ggml-metal.h' << 'ggml-metal.m'
end
unless $GGML_METAL_EMBED_LIBRARY
disabled_files << 'ggml-metal.metal'
end
unless $OBJ_ALL&.include? 'ggml-blas.o'
disabled_files << 'ggml-blas.h' << 'ggml-blas.cpp'
end
disabled_files.filter! {|file| File.exist? file}
disabled_files.each do |file|
File.rename file, "#{file}.disabled"
end
yield
disabled_files.each do |file|
File.rename "#{file}.disabled", file
end
end
if ENV['WHISPER_METAL']
$GGML_METAL ||= true
$DEPRECATE_WARNING ||= true
@ -35,16 +60,16 @@ if $GGML_METAL
$GGML_METAL_EMBED_LIBRARY = true
end
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples'
$MK_CPPFLAGS = ''
$MK_CFLAGS = '-std=c11 -fPIC'
$MK_CXXFLAGS = '-std=c++17 -fPIC'
$MK_NVCCFLAGS = '-std=c++17'
$MK_CXXFLAGS = '-std=c++11 -fPIC'
$MK_NVCCFLAGS = '-std=c++11'
$MK_LDFLAGS = ''
$OBJ_GGML = []
$OBJ_WHISPER = []
$OBJ_COMMON = []
$OBJ_SDL = []
$OBJ_GGML = ''
$OBJ_WHISPER = ''
$OBJ_COMMON = ''
$OBJ_SDL = ''
$MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
@ -111,6 +136,11 @@ unless ENV['RISCV']
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
if $UNAME_M.match? /aarch64.*/
$MK_CFLAGS << ' -mcpu=native'
$MK_CXXFLAGS << ' -mcpu=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
@ -118,11 +148,11 @@ end
unless ENV['GGML_NO_ACCELERATE']
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS'
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
$MK_LDFLAGS << ' -framework Accelerate'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
$OBJ_GGML << ' ggml-blas.o'
end
end
@ -130,20 +160,20 @@ if ENV['GGML_OPENBLAS']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
$OBJ_GGML << ' ggml-blas.o'
end
if ENV['GGML_OPENBLAS64']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
$OBJ_GGML << ' ggml-blas.o'
end
if $GGML_METAL
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
$OBJ_GGML << ' ggml-metal.o'
if ENV['GGML_METAL_NDEBUG']
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
@ -151,30 +181,21 @@ if $GGML_METAL
if $GGML_METAL_EMBED_LIBRARY
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
$OBJ_GGML << ' ggml-metal-embed.o'
end
end
$OBJ_GGML <<
'ggml/src/ggml.o' <<
'ggml/src/ggml-alloc.o' <<
'ggml/src/ggml-backend.o' <<
'ggml/src/ggml-backend-reg.o' <<
'ggml/src/ggml-opt.o' <<
'ggml/src/ggml-quants.o' <<
'ggml/src/ggml-threading.o' <<
'ggml/src/ggml-cpu/ggml-cpu.o' <<
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
'ggml/src/ggml-cpu/ggml-cpu-traits.o'
' ggml.o' <<
' ggml-alloc.o' <<
' ggml-backend.o' <<
' ggml-quants.o' <<
' ggml-aarch64.o'
$OBJ_WHISPER <<
'src/whisper.o'
' whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs << "ruby_whisper.o"
$OBJ_ALL = "#{$OBJ_GGML} #{$OBJ_WHISPER} #{$OBJ_COMMON} #{$OBJ_SDL}"
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
@ -183,16 +204,26 @@ $CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
$NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
$LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
create_makefile('whisper')
if $GGML_METAL_EMBED_LIBRARY
File.write 'depend', "$(OBJS): $(OBJS) ggml-metal-embed.o\n"
end
with_disabling_unsupported_files do
create_makefile('whisper')
end
File.open 'Makefile', 'a' do |file|
file.puts 'include scripts/get-flags.mk'
file.puts 'include cpu.mk'
file.puts 'include get-flags.mk'
if $GGML_METAL
file.puts 'include metal.mk'
if $GGML_METAL_EMBED_LIBRARY
# mkmf determines object files to compile dependent on existing *.{c,cpp,m} files
# but ggml-metal-embed.c doesn't exist on creating Makefile.
file.puts "objs := $(OBJS)"
file.puts "OBJS = $(objs) 'ggml-metal-embed.o'"
file.puts 'include metal-embed.mk'
end
end

View File

@ -1,17 +1,14 @@
ggml/src/ggml-metal/ggml-metal-embed.o: \
ggml/src/ggml-metal/ggml-metal.metal \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/src/ggml-common.h
ggml-metal-embed.o: \
ggml-metal.metal \
ggml-common.h
@echo "Embedding Metal library"
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
@rmdir ${TEMP_ASSEMBLY}
@sed -e '/#include "ggml-common.h"/r ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml-metal.metal > ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)
@echo ".incbin \"ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)
@$(AS) $(TEMP_ASSEMBLY) -o $@
@rm -f ${TEMP_ASSEMBLY}

View File

@ -1,6 +0,0 @@
ggml/src/ggml-metal/ggml-metal.o: \
ggml/src/ggml-metal/ggml-metal.m \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -1,5 +1,4 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
@ -36,26 +35,11 @@ extern "C" {
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
VALUE eError;
VALUE cSegment;
VALUE cModel;
static ID id_to_s;
static ID id_call;
static ID id___method__;
static ID id_to_enum;
static ID id_length;
static ID id_next;
static ID id_new;
static ID id_to_path;
static ID id_URI;
static ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
/*
* call-seq:
@ -104,39 +88,6 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
return rb_str_new2(str_full);
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
is_log_callback_finalized = true;
return Qnil;
}
/*
* call-seq:
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
rb_iv_set(self, "log_callback", log_callback);
rb_iv_set(self, "user_data", user_data);
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) {
if (is_log_callback_finalized) {
return;
}
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
VALUE udata = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
}, nullptr);
return Qnil;
}
static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) {
whisper_free(rw->context);
@ -191,74 +142,10 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
return container;
}
static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
}
static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
}
static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
@ -267,9 +154,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
/*
* call-seq:
* new("base.en") -> Whisper::Context
* new("path/to/model.bin") -> Whisper::Context
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
*/
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
@ -279,25 +164,6 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (TYPE(whisper_model_file_path) == T_STRING) {
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
}
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
@ -308,25 +174,8 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
return self;
}
static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = self;
rwp->params.new_segment_callback = new_segment_callback;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->progress_callback_container->context = self;
rwp->params.progress_callback = progress_callback;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = self;
rwp->params.abort_callback = abort_callback;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
}
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
/*
* transcribe a single file
@ -448,7 +297,80 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
register_callbacks(rwp, &self);
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
};
rwp->new_segment_callback_container->context = &self;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
};
rwp->progress_callback_container->context = &self;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->params.abort_callback = [](void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
};
rwp->abort_callback_container->context = &self;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
@ -467,290 +389,6 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
return self;
}
/*
* call-seq:
* model_n_vocab -> Integer
*/
VALUE ruby_whisper_model_n_vocab(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
/*
* call-seq:
* model_n_audio_ctx -> Integer
*/
VALUE ruby_whisper_model_n_audio_ctx(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
/*
* call-seq:
* model_n_audio_state -> Integer
*/
VALUE ruby_whisper_model_n_audio_state(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
/*
* call-seq:
* model_n_audio_head -> Integer
*/
VALUE ruby_whisper_model_n_audio_head(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
/*
* call-seq:
* model_n_audio_layer -> Integer
*/
VALUE ruby_whisper_model_n_audio_layer(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
/*
* call-seq:
* model_n_text_ctx -> Integer
*/
VALUE ruby_whisper_model_n_text_ctx(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
/*
* call-seq:
* model_n_text_state -> Integer
*/
VALUE ruby_whisper_model_n_text_state(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
/*
* call-seq:
* model_n_text_head -> Integer
*/
VALUE ruby_whisper_model_n_text_head(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
/*
* call-seq:
* model_n_text_layer -> Integer
*/
VALUE ruby_whisper_model_n_text_layer(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
/*
* call-seq:
* model_n_mels -> Integer
*/
VALUE ruby_whisper_model_n_mels(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
/*
* call-seq:
* model_ftype -> Integer
*/
VALUE ruby_whisper_model_ftype(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
/*
* call-seq:
* model_type -> String
*/
VALUE ruby_whisper_model_type(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
/*
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*
* call-seq:
* full(params, samples, n_samples) -> nil
* full(params, samples) -> nil
*
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
*/
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
if (argc < 2 || argc > 3) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
if (argc == 3) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// TODO: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError appropriately
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return Qnil;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
* Result is stored in the default state of the context
* Not thread safe if executed in parallel on the same context.
* It seems this approach can offer some speedup in some cases.
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
*
* call-seq:
* full_parallel(params, samples) -> nil
* full_parallel(params, samples, n_samples) -> nil
* full_parallel(params, samples, n_samples, n_processors) -> nil
* full_parallel(params, samples, nil, n_processors) -> nil
*/
static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
if (argc < 2 || argc > 4) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
int n_processors;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
switch (argc) {
case 2:
n_processors = 1;
break;
case 3:
n_processors = 1;
break;
case 4:
n_processors = NUM2INT(argv[3]);
break;
}
if (argc >= 3 && !NIL_P(argv[2])) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// FIXME: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return Qnil;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Number of segments.
*
@ -847,18 +485,6 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
return rb_str_new2(text);
}
/*
* call-seq:
* full_get_segment_no_speech_prob(segment_index) -> Float
*/
static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
}
/*
* params.language = "auto" | "en", etc...
*
@ -1015,19 +641,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
}
/*
* call-seq:
* suppress_nst = force_suppress -> force_suppress
* suppress_non_speech_tokens = force_suppress -> force_suppress
*/
static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_nst, value)
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
}
/*
* If true, suppresses non-speech-tokens.
*
* call-seq:
* suppress_nst -> bool
* suppress_non_speech_tokens -> bool
*/
static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_nst)
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
}
/*
* If true, enables token-level timestamps.
@ -1297,25 +923,6 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* no_speech_thold -> Float
*/
static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold);
}
/*
* call-seq:
* no_speech_thold = threshold -> threshold
*/
static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value;
}
/*
* Sets new segment callback, called for every newly generated text segment.
*
@ -1408,9 +1015,7 @@ typedef struct {
int index;
} ruby_whisper_segment;
typedef struct {
VALUE context;
} ruby_whisper_model;
VALUE cSegment;
static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
rb_gc_mark(rws->context);
@ -1583,284 +1188,31 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
return rb_str_new2(text);
}
/*
* call-seq:
* no_speech_prob -> Float
*/
static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
static VALUE rb_whisper_model_initialize(VALUE context) {
ruby_whisper_model *rwm;
const VALUE model = ruby_whisper_model_allocate(cModel);
Data_Get_Struct(model, ruby_whisper_model, rwm);
rwm->context = context;
return model;
};
/*
* call-seq:
* model -> Whisper::Model
*/
static VALUE ruby_whisper_get_model(VALUE self) {
return rb_whisper_model_initialize(self);
}
/*
* call-seq:
* n_vocab -> Integer
*/
static VALUE ruby_whisper_c_model_n_vocab(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
/*
* call-seq:
* n_audio_ctx -> Integer
*/
static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
/*
* call-seq:
* n_audio_state -> Integer
*/
static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
/*
* call-seq:
* n_audio_head -> Integer
*/
static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
/*
* call-seq:
* n_audio_layer -> Integer
*/
static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
/*
* call-seq:
* n_text_ctx -> Integer
*/
static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
/*
* call-seq:
* n_text_state -> Integer
*/
static VALUE ruby_whisper_c_model_n_text_state(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
/*
* call-seq:
* n_text_head -> Integer
*/
static VALUE ruby_whisper_c_model_n_text_head(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
/*
* call-seq:
* n_text_layer -> Integer
*/
static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
/*
* call-seq:
* n_mels -> Integer
*/
static VALUE ruby_whisper_c_model_n_mels(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
/*
* call-seq:
* ftype -> Integer
*/
static VALUE ruby_whisper_c_model_ftype(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
/*
* call-seq:
* type -> String
*/
static VALUE ruby_whisper_c_model_type(VALUE self) {
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
const int c_code = NUM2INT(code);
const char *raw_message;
switch (c_code) {
case -2:
raw_message = "failed to compute log mel spectrogram";
break;
case -3:
raw_message = "failed to auto-detect language";
break;
case -4:
raw_message = "too many decoders requested";
break;
case -5:
raw_message = "audio_ctx is larger than the maximum allowed";
break;
case -6:
raw_message = "failed to encode";
break;
case -7:
raw_message = "whisper_kv_cache_init() failed for self-attention cache";
break;
case -8:
raw_message = "failed to decode";
break;
case -9:
raw_message = "failed to decode";
break;
default:
raw_message = "unknown error";
break;
}
const VALUE message = rb_str_new2(raw_message);
rb_call_super(1, &message);
rb_iv_set(self, "@code", code);
return self;
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
id___method__ = rb_intern("__method__");
id_to_enum = rb_intern("to_enum");
id_length = rb_intern("length");
id_next = rb_intern("next");
id_new = rb_intern("new");
id_to_path = rb_intern("to_path");
id_URI = rb_intern("URI");
id_pre_converted_models = rb_intern("pre_converted_models");
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
@ -1882,8 +1234,8 @@ void Init_whisper() {
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
@ -1912,8 +1264,6 @@ void Init_whisper() {
rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
@ -1922,9 +1272,6 @@ void Init_whisper() {
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
rb_define_attr(eError, "code", true, false);
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
// High leve
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
@ -1937,25 +1284,6 @@ void Init_whisper() {
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0);
rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0);
rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0);
rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0);
rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0);
rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0);
rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0);
rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0);
rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0);
rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0);
rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0);
rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0);
rb_require("whisper/model/uri");
}
#ifdef __cplusplus
}

View File

@ -1,5 +1,5 @@
#ifndef RUBY_WHISPER_H
#define RUBY_WHISPER_H
#ifndef __RUBY_WHISPER_H
#define __RUBY_WHISPER_H
#include "whisper.h"

View File

@ -1,6 +0,0 @@
require "yaml"
sources = `git ls-files -z ../..`.split("\x0")
paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
paths.delete "bindings/ruby/**"
EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources

View File

@ -0,0 +1,29 @@
---
- ../../src/whisper.cpp
- ../../include/whisper.h
- ../../ggml/src/ggml.c
- ../../ggml/src/ggml-impl.h
- ../../ggml/src/ggml-aarch64.h
- ../../ggml/src/ggml-aarch64.c
- ../../ggml/src/ggml-alloc.c
- ../../ggml/src/ggml-backend-impl.h
- ../../ggml/src/ggml-backend.cpp
- ../../ggml/src/ggml-common.h
- ../../ggml/src/ggml-quants.h
- ../../ggml/src/ggml-quants.c
- ../../ggml/src/ggml-cpu-impl.h
- ../../ggml/src/ggml-metal.m
- ../../ggml/src/ggml-metal.metal
- ../../ggml/src/ggml-blas.cpp
- ../../ggml/include/ggml.h
- ../../ggml/include/ggml-alloc.h
- ../../ggml/include/ggml-backend.h
- ../../ggml/include/ggml-cuda.h
- ../../ggml/include/ggml-kompute.h
- ../../ggml/include/ggml-metal.h
- ../../ggml/include/ggml-sycl.h
- ../../ggml/include/ggml-vulkan.h
- ../../ggml/include/ggml-blas.h
- ../../scripts/get-flags.mk
- ../../examples/dr_wav.h
- ../../LICENSE

View File

@ -1,163 +0,0 @@
require "uri"
require "net/http"
require "time"
require "pathname"
require "io/console/size"
module Whisper
class Model
class URI
def initialize(uri)
@uri = URI(uri)
end
def to_path
cache
cache_path.to_path
end
def clear_cache
path = cache_path
path.delete if path.exist?
end
private
def cache_path
base_cache_dir/@uri.host/@uri.path[1..]
end
def base_cache_dir
base = case RUBY_PLATFORM
when /mswin|mingw/
ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
when /darwin/
Pathname(Dir.home)/"Library/Caches"
else
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
end
base/"whisper.cpp"
end
def cache
path = cache_path
headers = {}
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
request @uri, headers
path
end
def request(uri, headers)
Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
request = Net::HTTP::Get.new(uri, headers)
http.request request do |response|
case response
when Net::HTTPNotModified
# noop
when Net::HTTPOK
download response
when Net::HTTPRedirection
request URI(response["location"]), headers
else
return if headers.key?("if-modified-since") # Use cache file
raise "#{response.code} #{response.message}\n#{response.body}"
end
end
end
end
def download(response)
path = cache_path
path.dirname.mkpath unless path.dirname.exist?
downloading_path = Pathname("#{path}.downloading")
size = response.content_length
downloading_path.open "wb" do |file|
downloaded = 0
response.read_body do |chunk|
file << chunk
downloaded += chunk.bytesize
show_progress downloaded, size
end
$stderr.puts
end
downloading_path.rename path
end
def show_progress(current, size)
progress_rate_available = size && $stderr.tty?
unless @prev
@prev = Time.now
$stderr.puts "Downloading #{@uri} to #{cache_path}"
end
now = Time.now
if progress_rate_available
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
else
return if now - @prev < 1
$stderr.print "."
end
@prev = now
end
def format_bytesize(bytesize)
return "0.0 B" if bytesize.zero?
units = %w[B KiB MiB GiB TiB]
exp = (Math.log(bytesize) / Math.log(1024)).to_i
format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
end
end
@pre_converted_models = %w[
tiny
tiny.en
tiny-q5_1
tiny.en-q5_1
tiny-q8_0
base
base.en
base-q5_1
base.en-q5_1
base-q8_0
small
small.en
small.en-tdrz
small-q5_1
small.en-q5_1
small-q8_0
medium
medium.en
medium-q5_0
medium.en-q5_0
medium-q8_0
large-v1
large-v2
large-v2-q5_0
large-v2-q8_0
large-v3
large-v3-q5_0
large-v3-turbo
large-v3-turbo-q5_0
large-v3-turbo-q8_0
].each_with_object({}) {|name, models|
models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
}
class << self
attr_reader :pre_converted_models
end
end
end

View File

@ -1,153 +0,0 @@
module Whisper
interface _Samples
def length: () -> Integer
def each: { (Float) -> void } -> void
end
type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
LOG_LEVEL_NONE: Integer
LOG_LEVEL_INFO: Integer
LOG_LEVEL_WARN: Integer
LOG_LEVEL_ERROR: Integer
LOG_LEVEL_DEBUG: Integer
LOG_LEVEL_CONT: Integer
def self.lang_max_id: () -> Integer
def self.lang_id: (string name) -> Integer
def self.lang_str: (Integer id) -> String
def self.lang_str_full: (Integer id) -> String
def self.log_set=: (log_callback) -> log_callback
def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
class Context
def initialize: (string | _ToPath | ::URI::HTTP ) -> void
def transcribe: (string, Params) -> void
| (string, Params) { (String) -> void } -> void
def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer
def model_n_audio_state: () -> Integer
def model_n_text_head: () -> Integer
def model_n_text_layer: () -> Integer
def model_n_mels: () -> Integer
def model_ftype: () -> Integer
def model_type: () -> String
def full_n_segments: () -> Integer
def full_lang_id: () -> Integer
def full_get_segment_t0: (Integer) -> Integer
def full_get_segment_t1: (Integer) -> Integer
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
def full_get_segment_text: (Integer) -> String
def full_get_segment_no_speech_prob: (Integer) -> Float
def full: (Params, Array[Float], ?Integer) -> void
| (Params, _Samples, ?Integer) -> void
def full_parallel: (Params, Array[Float], ?Integer) -> void
| (Params, _Samples, ?Integer) -> void
| (Params, _Samples, ?Integer?, Integer) -> void
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
def model: () -> Model
end
class Params
def initialize: () -> void
def language=: (String) -> String # TODO: Enumerate lang names
def language: () -> String
def translate=: (boolish) -> boolish
def translate: () -> (true | false)
def no_context=: (boolish) -> boolish
def no_context: () -> (true | false)
def single_segment=: (boolish) -> boolish
def single_segment: () -> (true | false)
def print_special=: (boolish) -> boolish
def print_special: () -> (true | false)
def print_progress=: (boolish) -> boolish
def print_progress: () -> (true | false)
def print_realtime=: (boolish) -> boolish
def print_realtime: () -> (true | false)
def print_timestamps=: (boolish) -> boolish
def print_timestamps: () -> (true | false)
def suppress_blank=: (boolish) -> boolish
def suppress_blank: () -> (true | false)
def suppress_nst=: (boolish) -> boolish
def suppress_nst: () -> (true | false)
def token_timestamps=: (boolish) -> boolish
def token_timestamps: () -> (true | false)
def split_on_word=: (boolish) -> boolish
def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS
def initial_prompt: () -> String
def diarize=: (boolish) -> boolish
def diarize: () -> (true | false)
def offset=: (Integer) -> Integer
def offset: () -> Integer
def duration=: (Integer) -> Integer
def duration: () -> Integer
def max_text_tokens=: (Integer) -> Integer
def max_text_tokens: () -> Integer
def temperature=: (Float) -> Float
def temperature: () -> Float
def max_initial_ts=: (Float) -> Float
def max_initial_ts: () -> Float
def length_penalty=: (Float) -> Float
def length_penalty: () -> Float
def temperature_inc=: (Float) -> Float
def temperature_inc: () -> Float
def entropy_thold=: (Float) -> Float
def entropy_thold: () -> Float
def logprob_thold=: (Float) -> Float
def logprob_thold: () -> Float
def no_speech_thold=: (Float) -> Float
def no_speech_thold: () -> Float
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback_user_data=: (Object) -> Object
def progress_callback=: (progress_callback) -> progress_callback
def progress_callback_user_data=: (Object) -> Object
def abort_callback=: (abort_callback) -> abort_callback
def abort_callback_user_data=: (Object) -> Object
def on_new_segment: { (Segment) -> void } -> void
def on_progress: { (Integer) -> void } -> void
def abort_on: { (Object) -> boolish } -> void
end
class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def initialize: () -> void
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
def n_audio_state: () -> Integer
def n_audio_head: () -> Integer
def n_audio_layer: () -> Integer
def n_text_ctx: () -> Integer
def n_text_state: () -> Integer
def n_text_head: () -> Integer
def n_text_layer: () -> Integer
def n_mels: () -> Integer
def ftype: () -> Integer
def type: () -> String
class URI
def initialize: (string | ::URI::HTTP) -> void
def to_path: -> String
def clear_cache: -> void
end
end
class Segment
def initialize: () -> void
def start_time: () -> Integer
def end_time: () -> Integer
def speaker_next_turn?: () -> (true | false)
def text: () -> String
def no_speech_prob: () -> Float
end
class Error < StandardError
attr_reader code: Integer
def initialize: (Integer) -> void
end
end

View File

@ -1,24 +0,0 @@
require "test/unit"
require "whisper"
require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
private
def whisper
self.class.whisper
end
end

View File

@ -1,5 +0,0 @@
Makefile
jfk_reader.o
jfk_reader.so
jfk_reader.bundle
jfk_reader.dll

View File

@ -1,3 +0,0 @@
require "mkmf"
create_makefile("jfk_reader")

View File

@ -1,68 +0,0 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include <ruby/encoding.h>
static VALUE
jfk_reader_initialize(VALUE self, VALUE audio_path)
{
rb_iv_set(self, "audio_path", audio_path);
return Qnil;
}
static bool
jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
{
VALUE audio_path = rb_iv_get(obj, "audio_path");
const char *audio_path_str = StringValueCStr(audio_path);
const int n_samples = 176000;
float *data = (float *)malloc(n_samples * sizeof(float));
short *samples = (short *)malloc(n_samples * sizeof(short));
FILE *file = fopen(audio_path_str, "rb");
fseek(file, 78, SEEK_SET);
fread(samples, sizeof(short), n_samples, file);
fclose(file);
for (int i = 0; i < n_samples; i++) {
data[i] = samples[i]/32768.0;
}
view->obj = obj;
view->data = (void *)data;
view->byte_size = sizeof(float) * n_samples;
view->readonly = true;
view->format = "f";
view->item_size = sizeof(float);
view->item_desc.components = NULL;
view->item_desc.length = 0;
view->ndim = 1;
view->shape = NULL;
view->sub_offsets = NULL;
view->private_data = NULL;
return true;
}
static bool
jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
{
return true;
}
static bool
jfk_reader_memory_view_available_p(const VALUE obj)
{
return true;
}
static const rb_memory_view_entry_t jfk_reader_view_entry = {
jfk_reader_get_memory_view,
jfk_reader_release_memory_view,
jfk_reader_memory_view_available_p
};
void Init_jfk_reader(void)
{
VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
}

View File

@ -1,11 +1,14 @@
require_relative "helper"
require "test/unit"
require "whisper"
class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
class TestCallback < TestBase
def setup
GC.start
@params = Whisper::Params.new
@whisper = Whisper::Context.new("base.en")
@audio = File.join(AUDIO)
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
end
def test_new_segment_callback

View File

@ -1,20 +0,0 @@
require_relative "helper"
class TestError < TestBase
def test_error
error = Whisper::Error.new(-2)
assert_equal "failed to compute log mel spectrogram", error.message
assert_equal -2, error.code
end
def test_unknown_error
error = Whisper::Error.new(-20)
assert_equal "unknown error", error.message
end
def test_non_int_code
assert_raise TypeError do
error = Whisper::Error.new("non int")
end
end
end

View File

@ -1,109 +0,0 @@
require_relative "helper"
require "pathname"
class TestModel < TestBase
def test_model
whisper = Whisper::Context.new("base.en")
assert_instance_of Whisper::Model, whisper.model
end
def test_attributes
whisper = Whisper::Context.new("base.en")
model = whisper.model
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_gc
model = Whisper::Context.new("base.en").model
GC.start
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_pathname
path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
whisper = Whisper::Context.new(path)
model = whisper.model
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_auto_download
path = Whisper::Model.pre_converted_models["base.en"].to_path
assert_path_exist path
assert_equal 147964211, File.size(path)
end
def test_uri_string
path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin"
whisper = Whisper::Context.new(path)
model = whisper.model
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_uri
path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin")
whisper = Whisper::Context.new(path)
model = whisper.model
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
end

View File

@ -1,9 +1,9 @@
require_relative "helper"
require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < TestBase
class TestPackage < Test::Unit::TestCase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
@ -23,7 +23,7 @@ class TestPackage < TestBase
version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end
end

View File

@ -1,6 +1,7 @@
require_relative "helper"
require 'test/unit'
require 'whisper'
class TestParams < TestBase
class TestParams < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
@ -89,11 +90,11 @@ class TestParams < TestBase
assert !@params.suppress_blank
end
def test_suppress_nst
@params.suppress_nst = true
assert @params.suppress_nst
@params.suppress_nst = false
assert !@params.suppress_nst
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@ -151,10 +152,4 @@ class TestParams < TestBase
@params.logprob_thold = -0.5
assert_in_delta -0.5, @params.logprob_thold
end
def test_no_speech_thold
assert_in_delta 0.6, @params.no_speech_thold
@params.no_speech_thold = 0.2
assert_in_delta 0.2, @params.no_speech_thold
end
end

View File

@ -1,6 +1,21 @@
require_relative "helper"
require "test/unit"
require "whisper"
class TestSegment < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
class TestSegment < TestBase
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
@ -32,14 +47,6 @@ class TestSegment < TestBase
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
@ -53,7 +60,7 @@ class TestSegment < TestBase
end
index += 1
end
whisper.transcribe(AUDIO, params)
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
assert_equal 0, seg.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
end
@ -69,6 +76,12 @@ class TestSegment < TestBase
assert_same seg, segment
return
end
whisper.transcribe(AUDIO, params)
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
end
private
def whisper
self.class.whisper
end
end

View File

@ -1,26 +1,41 @@
require_relative "helper"
require "stringio"
require "etc"
require 'whisper'
require 'test/unit'
# Exists to detect memory-related bug
Whisper.log_set ->(level, buffer, user_data) {}, nil
class TestWhisper < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
class TestWhisper < TestBase
def setup
@params = Whisper::Params.new
end
def test_whisper
@whisper = Whisper::Context.new("base.en")
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(AUDIO, params) {|text|
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params) {|text|
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
}
end
sub_test_case "After transcription" do
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
def whisper
self.class.whisper
end
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
end
@ -55,12 +70,6 @@ class TestWhisper < TestBase
def test_full_get_segment_text
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end
def test_full_get_segment_no_speech_prob
prob = whisper.full_get_segment_no_speech_prob(0)
assert prob > 0.0
assert prob < 1.0
end
end
def test_lang_max_id
@ -87,131 +96,4 @@ class TestWhisper < TestBase
Whisper.lang_str_full(Whisper.lang_max_id + 1)
end
end
def test_log_set
user_data = Object.new
logs = []
log_callback = ->(level, buffer, udata) {
logs << [level, buffer, udata]
}
Whisper.log_set log_callback, user_data
Whisper::Context.new("base.en")
assert logs.length > 30
logs.each do |log|
assert_include [Whisper::LOG_LEVEL_DEBUG, Whisper::LOG_LEVEL_INFO, Whisper::LOG_LEVEL_WARN], log[0]
assert_same user_data, log[2]
end
end
def test_log_suppress
stderr = $stderr
Whisper.log_set ->(level, buffer, user_data) {
# do nothing
}, nil
dev = StringIO.new("")
$stderr = dev
Whisper::Context.new("base.en")
assert_empty dev.string
ensure
$stderr = stderr
end
sub_test_case "full" do
def setup
super
@whisper = Whisper::Context.new("base.en")
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
end
def test_full
@whisper.full(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_without_length
@whisper.full(@params, @samples)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator
samples = @samples.each
@whisper.full(@params, samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator_without_length
samples = @samples.each
assert_raise ArgumentError do
@whisper.full(@params, samples)
end
end
def test_full_enumerator_with_too_large_length
samples = @samples.each.take(10).to_enum
assert_raise StopIteration do
@whisper.full(@params, samples, 11)
end
end
def test_full_with_memory_view
samples = JFKReader.new(AUDIO)
@whisper.full(@params, samples)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_parallel
@whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_with_memory_view
samples = JFKReader.new(AUDIO)
@whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length_and_n_processors
@whisper.full_parallel(@params, @samples)
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length
@whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_n_processors
@whisper.full_parallel(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
end
end

View File

@ -1,36 +1,36 @@
require_relative "extsources"
require "yaml"
Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.1'
s.date = '2024-12-19'
s.version = '1.3.0'
s.date = '2024-05-14'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md']
s.files = `git ls-files . -z`.split("\x0") +
EXTSOURCES.collect {|file|
YAML.load_file("extsources.yaml").collect {|file|
basename = File.basename(file)
if s.extra_rdoc_files.include?(basename)
basename
else
file.sub("../..", "ext")
File.join("ext", basename)
end
}
s.summary = %q{Ruby whisper.cpp bindings}
s.test_files = s.files.select {|file| file.start_with? "tests/"}
s.test_files = ["tests/test_whisper.rb"]
s.extensions << 'ext/extconf.rb'
s.required_ruby_version = '>= 3.1.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggerganov/whisper.cpp'
s.rdoc_options = ['--main', 'README.md']
s.rdoc_options = ['--main', '../../README.md']
s.platform = Gem::Platform::RUBY
s.licenses = ['MIT']
end

View File

@ -13,4 +13,5 @@ set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -1,10 +1,10 @@
prefix=@CMAKE_INSTALL_PREFIX@
exec_prefix=${prefix}
libdir=${exec_prefix}/lib
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
includedir=${prefix}/include
Name: whisper
Description: Port of OpenAI's Whisper model in C/C++
Version: @PROJECT_VERSION@
Libs: -L${libdir} -lggml -lggml-base -lwhisper
Libs: -L${libdir} -lwhisper
Cflags: -I${includedir}

View File

@ -97,29 +97,52 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN)
add_subdirectory(whisper.wasm)
set_target_properties(libmain PROPERTIES FOLDER "libs")
add_subdirectory(stream.wasm)
set_target_properties(libstream PROPERTIES FOLDER "libs")
add_subdirectory(command.wasm)
set_target_properties(libcommand PROPERTIES FOLDER "libs")
#add_subdirectory(talk.wasm)
#set_target_properties(libtalk PROPERTIES FOLDER "libs")
add_subdirectory(bench.wasm)
set_target_properties(libbench PROPERTIES FOLDER "libs")
elseif(CMAKE_JS_VERSION)
add_subdirectory(addon.node)
set_target_properties(addon.node PROPERTIES FOLDER "examples")
else()
add_subdirectory(cli)
add_subdirectory(bench)
add_subdirectory(main)
set_target_properties(main PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
add_subdirectory(stream)
set_target_properties(stream PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)
add_subdirectory(server)
set_target_properties(server PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
add_subdirectory(command)
set_target_properties(command PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)
add_subdirectory(bench)
set_target_properties(bench PROPERTIES FOLDER "examples")
add_subdirectory(quantize)
if (WHISPER_SDL2)
add_subdirectory(stream)
add_subdirectory(command)
add_subdirectory(talk-llama)
add_subdirectory(lsp)
if (GGML_SYCL)
add_subdirectory(sycl)
endif()
endif (WHISPER_SDL2)
add_subdirectory(deprecation-warning)
set_target_properties(quantize PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
# TODO: disabled until update
# https://github.com/ggerganov/whisper.cpp/issues/1818
#add_subdirectory(talk)
#set_target_properties(talk PROPERTIES FOLDER "examples")
add_subdirectory(talk-llama)
set_target_properties(talk-llama PROPERTIES FOLDER "examples")
add_subdirectory(lsp)
set_target_properties(lsp PROPERTIES FOLDER "examples")
if (GGML_SYCL)
add_subdirectory(sycl)
set_target_properties(sycl PROPERTIES FOLDER "examples")
endif()
endif (WHISPER_SDL2)
endif()
if (WHISPER_SDL2)
add_subdirectory(wchess)
set_target_properties(wchess PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)

View File

@ -1,8 +1,6 @@
set(TARGET whisper-bench)
set(TARGET bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,4 +1,4 @@
# whisper.cpp/examples/bench
# bench
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
@ -7,8 +7,11 @@ of the performance of the model for various setups.
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
```bash
# run the bench too on the small.en model using 4 threads
$ ./build/bin/whisper-bench -m ./models/ggml-small.en.bin -t 4
# build the bench tool
$ make bench
# run it on the small.en model using 4 threads
$ ./bench -m ./models/ggml-small.en.bin -t 4
whisper_model_load: loading model from './models/ggml-small.en.bin'
whisper_model_load: n_vocab = 51864

View File

@ -1,10 +1,9 @@
if (WHISPER_SDL2)
set(TARGET whisper-command)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)
endif ()

View File

@ -1,14 +1,14 @@
# whisper.cpp/examples/command
# command
This is a basic Voice Assistant example that accepts voice commands from the microphone.
More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/issues/171).
```bash
# Run with default arguments and small model
./whisper-command -m ./models/ggml-small.en.bin -t 8
./command -m ./models/ggml-small.en.bin -t 8
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./whisper-command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
@ -23,10 +23,10 @@ Initial tests show that this approach might be extremely efficient in terms of p
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./whisper-command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
./whisper-command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
@ -34,7 +34,7 @@ https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9
## Building
The `whisper-command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -47,6 +47,5 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release
make command
```

View File

@ -72,6 +72,9 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_IQ4_XS:
case GGML_FTYPE_MOSTLY_IQ1_M:
case GGML_FTYPE_MOSTLY_BF16:
case GGML_FTYPE_MOSTLY_Q4_0_4_4:
case GGML_FTYPE_MOSTLY_Q4_0_4_8:
case GGML_FTYPE_MOSTLY_Q4_0_8_8:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -209,6 +212,9 @@ bool ggml_common_quantize_0(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_COUNT:

View File

@ -1,7 +1,5 @@
#include "common-sdl.h"
#include <cstdio>
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;

View File

@ -1,4 +0,0 @@
add_executable(main ./deprecation-warning.cpp)
add_executable(bench ./deprecation-warning.cpp)
add_executable(stream ./deprecation-warning.cpp)
add_executable(command ./deprecation-warning.cpp)

View File

@ -1,17 +0,0 @@
# Migration notice for binary filenames
> [!IMPORTANT]
[2024 Dec 20] Binaries have been renamed w/ a `whisper-` prefix. `main` is now `whisper-cli`, `server` is `whisper-server`, etc (https://github.com/ggerganov/whisper.cpp/pull/2648)
This migration was important, but it is a breaking change that may not always be immediately obvious to users.
Please update all scripts and workflows to use the new binary names.
| Old Filename | New Filename |
| ---- | ---- |
| main | whisper-cli |
| bench | whisper-bench |
| stream | whisper-stream |
| command | whisper-command |
| server | whisper-server |
| talk-llama | whisper-talk-llama |

View File

@ -1,38 +0,0 @@
// Warns users that this filename was deprecated, and provides a link for more information.
#include <cstdio>
#include <string>
// Main
int main(int argc, char** argv) {
std::string filename = "main";
if (argc >= 1) {
filename = argv[0];
}
// Get only the program name from the full path
size_t pos = filename.find_last_of("/\\");
if (pos != std::string::npos) {
filename = filename.substr(pos+1);
}
// Append "whisper-" to the beginning of filename to get the replacemnt filename
std::string replacement_filename = "whisper-" + filename;
// The exception is if the filename is "main", then our replacement filename is "whisper-cli"
if (filename == "main") {
replacement_filename = "whisper-cli";
}
if (filename == "main.exe") {
replacement_filename = "whisper-cli.exe";
}
fprintf(stdout, "\n");
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
fprintf(stdout, " See https://github.com/ggerganov/whisper.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n");
fprintf(stdout, "\n");
return EXIT_FAILURE;
}

View File

@ -204,6 +204,8 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
const size_t errbuffsize = 1024;
char errbuff[errbuffsize];
av_register_all(); // from avformat. Still a must-have call for ffmpeg v3! (can be skipped for later versions)
fmt_ctx = avformat_alloc_context();
avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ);
LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ);

View File

@ -11,7 +11,7 @@
# Press Ctrl+C to stop recording
#
executable="./build/bin/whisper-cli"
executable="./main"
model="base.en"
model_path="models/ggml-$model.bin"
@ -46,7 +46,7 @@ ffmpeg -y -i ./rec.wav -ar 16000 -ac 1 -c:a pcm_s16le ./rec16.wav > /dev/null 2>
# run Whisper
echo "Processing ..."
${executable} -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
./main -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
# generate Karaoke video
echo "Generating video ..."

View File

@ -14,7 +14,7 @@ model="base.en"
check_requirements()
{
if ! command -v ./build/bin/whisper-cli &>/dev/null; then
if ! command -v ./main &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
@ -100,7 +100,7 @@ while [ $running -eq 1 ]; do
err=$(cat /tmp/whisper-live.err | wc -l)
done
./build/bin/whisper-cli -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
sleep 1
@ -109,4 +109,4 @@ while [ $running -eq 1 ]; do
done
killall -v ffmpeg
killall -v whisper-cli
killall -v main

View File

@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.suppress_nst = true;
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
wparams.prompt_tokens = cs.prompt_tokens.data();
wparams.prompt_n_tokens = cs.prompt_tokens.size();
// TODO: properly expose as option
wparams.suppress_nst = true;
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {

View File

@ -1,8 +1,6 @@
set(TARGET whisper-cli)
add_executable(${TARGET} cli.cpp)
set(TARGET main)
add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,12 +1,12 @@
# whisper.cpp/examples/cli
# main
This is the main example demonstrating most of the functionality of the Whisper model.
It can be used as a reference for using the `whisper.cpp` library in other projects.
```
./build/bin/whisper-cli -h
./main -h
usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
@ -20,12 +20,9 @@ options:
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-ac N, --audio-ctx N [0 ] audio context size (0 - all)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
@ -41,23 +38,16 @@ options:
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-np, --no-prints [false ] do not print anything other than the results
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
-fa, --flash-attn [false ] flash attention
--suppress-regex REGEX [ ] regular expression matching tokens to suppress
--grammar GRAMMAR [ ] GBNF grammar to guide decoding
--grammar-rule RULE [ ] top-level GBNF grammar rule name
--grammar-penalty N [100.0 ] scales down logits of nongrammar tokens
```

View File

@ -43,7 +43,6 @@ struct whisper_params {
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float no_speech_thold = 0.6f;
float grammar_penalty = 100.0f;
float temperature = 0.0f;
float temperature_inc = 0.2f;
@ -71,7 +70,6 @@ struct whisper_params {
bool log_score = false;
bool use_gpu = true;
bool flash_attn = false;
bool suppress_nst = false;
std::string language = "en";
std::string prompt;
@ -106,11 +104,6 @@ static char * whisper_param_turn_lowercase(char * in){
return in;
}
static char * requires_value_error(const std::string & arg) {
fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
exit(0);
}
static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
@ -129,23 +122,21 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
whisper_print_usage(argc, argv, params);
exit(0);
}
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -157,31 +148,30 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; }
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -212,7 +202,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
@ -245,7 +234,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
@ -1133,12 +1121,9 @@ int main(int argc, char ** argv) {
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.no_speech_thold = params.no_speech_thold;
wparams.no_timestamps = params.no_timestamps;
wparams.suppress_nst = params.suppress_nst;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
const auto & grammar_parsed = params.grammar_parsed;

View File

@ -1,4 +1,4 @@
set(TARGET whisper-server)
set(TARGET server)
add_executable(${TARGET} server.cpp httplib.h)
include(DefaultTargetOptions)
@ -8,5 +8,3 @@ target_link_libraries(${TARGET} PRIVATE common json_cpp whisper ${CMAKE_THREAD_L
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32)
endif()
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,4 +1,4 @@
# whisper.cpp/examples/server
# whisper.cpp http server
Simple http server. WAV Files are passed to the inference model via http requests.
@ -7,9 +7,9 @@ https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-
## Usage
```
./build/bin/whisper-server -h
./server -h
usage: ./build/bin/whisper-server [options]
usage: ./bin/server [options]
options:
-h, --help [default] show this help message and exit

View File

@ -61,7 +61,6 @@ struct whisper_params {
float logprob_thold = -1.00f;
float temperature = 0.00f;
float temperature_inc = 0.20f;
float no_speech_thold = 0.6f;
bool debug_mode = false;
bool translate = false;
@ -77,7 +76,6 @@ struct whisper_params {
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = false;
bool suppress_nst = false;
std::string language = "en";
std::string prompt = "";
@ -136,9 +134,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, "\n");
}
@ -183,9 +179,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@ -479,14 +472,6 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
}
if (req.has_file("suppress_non_speech"))
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
}
if (req.has_file("suppress_nst"))
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
}
}
} // namespace
@ -692,8 +677,7 @@ int main(int argc, char ** argv) {
if (sparams.ffmpeg_converter) {
// if file is not wav, convert to wav
// write to temporary file
//const std::string temp_filename_base = std::tmpnam(nullptr);
const std::string temp_filename_base = "whisper-server-tmp"; // TODO: this is a hack, remove when the mutext is removed
const std::string temp_filename_base = std::tmpnam(nullptr);
const std::string temp_filename = temp_filename_base + ".wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
@ -727,6 +711,7 @@ int main(int argc, char ** argv) {
}
}
printf("Successfully loaded %s\n", filename.c_str());
// print system information
@ -794,7 +779,6 @@ int main(int argc, char ** argv) {
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature = params.temperature;
wparams.no_speech_thold = params.no_speech_thold;
wparams.temperature_inc = params.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
@ -802,8 +786,6 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.suppress_nst = params.suppress_nst;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
@ -947,7 +929,7 @@ int main(int argc, char ** argv) {
// TODO compression_ratio and no_speech_prob are not implemented yet
// segment["compression_ratio"] = 0;
segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
// segment["no_speech_prob"] = 0;
jres["segments"].push_back(segment);
}

View File

@ -1,10 +1,9 @@
if (WHISPER_SDL2)
set(TARGET whisper-stream)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)
endif ()

View File

@ -1,11 +1,11 @@
# whisper.cpp/examples/stream
# stream
This is a naive example of performing real-time inference on audio from your microphone.
The `whisper-stream` tool samples the audio every half a second and runs the transcription continously.
The `stream` tool samples the audio every half a second and runs the transcription continously.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```bash
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
@ -15,7 +15,7 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
Setting the `--step` argument to `0` enables the sliding window mode:
```bash
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 6 --step 0 --length 30000 -vth 0.6
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
```
In this mode, the tool will transcribe only after some speech activity is detected. A very
@ -27,7 +27,7 @@ a transcription block that is suitable for parsing.
## Building
The `whisper-stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -40,10 +40,21 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release
make stream
```
./build/bin/whisper-stream
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
```bash
whisper.cpp/examples/stream$ make stream
g++ stream.cpp -o stream
stream.cpp:6:10: fatal error: common/sdl.h: No such file or directory
6 | #include "common/sdl.h"
| ^~~~~~~~~~~~~~
compilation terminated.
make: *** [<builtin>: stream] Error 1
```
## Web version

View File

@ -5,5 +5,5 @@
set(TARGET ls-sycl-device)
add_executable(${TARGET} ls-sycl-device.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -7,16 +7,13 @@ cd build
source /opt/intel/oneapi/setvars.sh
#for FP16
#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON # faster for long-prompt inference
#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON # faster for long-prompt inference
#for FP32
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
#for other features from the examples, e.g. stream and talk link with SDL2:
#cmake .. -DGGML_SYCL=ON -DWHISPER_SDL2=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
#build example/main only
#cmake --build . --config Release --target main
#build all binary
cmake --build . --config Release -v
cmake --build . --config Release -v

View File

@ -1,5 +1,6 @@
if (WHISPER_SDL2)
set(TARGET whisper-talk-llama)
# talk-llama
set(TARGET talk-llama)
add_executable(${TARGET} talk-llama.cpp
llama.cpp
llama-vocab.cpp

View File

@ -1,4 +1,4 @@
# whisper.cpp/examples/talk-llama
# talk-llama
Talk with an LLaMA AI in your terminal
@ -12,7 +12,7 @@ https://github.com/ggerganov/whisper.cpp/assets/1991296/d97a3788-bf2a-4756-9a43-
## Building
The `whisper-talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -25,12 +25,11 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
# Build the "whisper-talk-llama" executable
cmake -B build -S . -DWHISPER_SDL2=ON
cmake --build build --config Release
# Build the "talk-llama" executable
make talk-llama
# Run it
./build/bin/whisper-talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
@ -38,16 +37,16 @@ cmake --build build --config Release
## Session
The `whisper-talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
To enable session support, use the `--session FILE` command line option when running the program. The `whisper-talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
Example usage:
```bash
./build/bin/whisper-talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
```
## TTS

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,8 @@
#include "llama-grammar.h"
#include <unordered_map>
struct llama_vocab;
struct llama_grammar;
@ -25,24 +27,3 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
struct llama_sampler * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const std::vector<std::vector<llama_token>>& seq_breakers);

View File

@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
}
// seed the work queue with all possible 2-character tokens.
for (int i = 1; i < (int) symbols.size(); ++i) {
for (size_t i = 1; i < symbols.size(); ++i) {
try_add_bigram(i - 1, i);
}
@ -418,7 +418,6 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
case LLAMA_VOCAB_PRE_TYPE_EXAONE:
case LLAMA_VOCAB_PRE_TYPE_MINERVA:
regex_exprs = {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@ -564,7 +563,7 @@ struct llm_tokenizer_bpe_session {
index++;
symbols.emplace_back(sym);
}
for (int i = 1; i < (int) symbols.size(); ++i) {
for (size_t i = 1; i < symbols.size(); ++i) {
add_new_bigram(i - 1, i);
}
@ -738,7 +737,7 @@ struct llm_tokenizer_wpm_session {
std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags_from_cpt(cpt);
const auto flags = unicode_cpt_flags(cpt);
if (flags.is_whitespace) {
if (words.back().size()) { // finish previous word if any
@ -1664,14 +1663,6 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
return vocab.special_eos_id;
}
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
return vocab.special_eot_id;
}
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
return vocab.special_eom_id;
}
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
return vocab.special_cls_id;
}
@ -1697,39 +1688,23 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
}
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_pre_id;
return vocab.special_prefix_id;
}
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_mid_id;
return vocab.special_middle_id;
}
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_suf_id;
return vocab.special_suffix_id;
}
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_pre_id;
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
return vocab.special_eot_id;
}
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_suf_id;
}
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_mid_id;
}
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_pad_id;
}
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_rep_id;
}
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
return vocab.special_fim_sep_id;
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
return vocab.special_eom_id;
}
int32_t llama_tokenize_impl(
@ -1967,19 +1942,3 @@ int32_t llama_detokenize_impl(
return total <= text_len_max ? total : -total;
}
std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) {
text.resize(-n_chars);
n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
}
text.resize(n_chars);
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
return text;
}

View File

@ -37,26 +37,20 @@ struct llama_vocab {
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens
// TODO: should we set all of these to LLAMA_TOKEN_NULL?
id special_bos_id = 1;
id special_eos_id = 2;
id special_eot_id = LLAMA_TOKEN_NULL;
id special_eom_id = LLAMA_TOKEN_NULL;
id special_unk_id = 0;
id special_sep_id = LLAMA_TOKEN_NULL;
id special_pad_id = LLAMA_TOKEN_NULL;
id special_cls_id = LLAMA_TOKEN_NULL;
id special_mask_id = LLAMA_TOKEN_NULL;
id special_sep_id = -1;
id special_pad_id = -1;
id special_cls_id = -1;
id special_mask_id = -1;
id linefeed_id = 13;
// fim tokens
id special_fim_pre_id = LLAMA_TOKEN_NULL;
id special_fim_suf_id = LLAMA_TOKEN_NULL;
id special_fim_mid_id = LLAMA_TOKEN_NULL;
id special_fim_pad_id = LLAMA_TOKEN_NULL;
id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
id linefeed_id = 13;
id special_prefix_id = -1;
id special_suffix_id = -1;
id special_middle_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
id special_eom_id = -1;
// set of all tokens that cause "end of generation"
std::set<id> special_eog_ids;
@ -110,26 +104,19 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
@ -149,12 +136,6 @@ int32_t llama_token_to_piece_impl(
int32_t lstrip,
bool special);
// check if token0 is contained as a prefix in token1
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1);
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
@ -163,8 +144,3 @@ int32_t llama_detokenize_impl(
int32_t text_len_max,
bool remove_special,
bool unparse_special);
std::string llama_detokenize(
const struct llama_vocab & vocab,
const std::vector<llama_token> & tokens,
bool special);

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,6 @@
#define LLAMA_H
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include <stddef.h>
@ -104,15 +103,12 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
};
enum llama_rope_type {
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
};
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
@ -174,9 +170,9 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
//LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack
//LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack
//LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
@ -188,8 +184,7 @@ extern "C" {
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
};
enum llama_pooling_type {
@ -210,7 +205,7 @@ extern "C" {
enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@ -222,7 +217,6 @@ extern "C" {
typedef struct llama_token_data_array {
// TODO: consider SoA
// NOTE: this pointer can be modified by the samplers
llama_token_data * data;
size_t size;
int64_t selected; // this is the index in the data array (i.e. not the token id)
@ -238,11 +232,8 @@ extern "C" {
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
//
typedef struct llama_batch {
int32_t n_tokens;
@ -253,6 +244,15 @@ extern "C" {
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
//
// pos[i] = all_pos_0 + i*all_pos_1
//
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
enum llama_model_kv_override_type {
@ -276,13 +276,13 @@ extern "C" {
};
struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
// main_gpu interpretation depends on split_mode:
// LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
// LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
// LLAMA_SPLIT_MODE_LAYER: ignored
int32_t main_gpu;
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
@ -433,7 +433,6 @@ extern "C" {
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API bool llama_supports_rpc (void);
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ -458,7 +457,6 @@ extern "C" {
// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - When retrieving a string, an extra byte must be allocated to account for the null terminator
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
@ -675,9 +673,6 @@ extern "C" {
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
//
// State / sessions
//
@ -780,15 +775,15 @@ extern "C" {
// Decoding
//
// Return batch for single sequence of tokens
// The sequence ID will be fixed to 0
// The position of the tokens will be tracked automatically by llama_decode
// Return batch for single sequence of tokens starting at pos_0
//
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
//
LLAMA_API struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens);
int32_t n_tokens,
llama_pos pos_0,
llama_seq_id seq_id);
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
@ -808,7 +803,7 @@ extern "C" {
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
// < 0 - error
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
@ -816,7 +811,7 @@ extern "C" {
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error. the KV cache state is restored to the state before this call
// < 0 - error
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
@ -901,7 +896,6 @@ extern "C" {
// Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@ -910,17 +904,11 @@ extern "C" {
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
// infill tokens
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
// Codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
//
// Tokenization
@ -995,9 +983,6 @@ extern "C" {
char * buf,
int32_t length);
// Get list of built-in chat templates
LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len);
//
// Sampling API
//
@ -1082,13 +1067,12 @@ extern "C" {
// available samplers:
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
"will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
@ -1099,18 +1083,16 @@ extern "C" {
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
/// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@ -1139,50 +1121,22 @@ extern "C" {
const char * grammar_str,
const char * grammar_root);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
float penalty_present); // 0.0 = disabled
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
const struct llama_model * model,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
int32_t n_vocab, // llama_n_vocab()
llama_token special_eos_id, // llama_token_eos()
llama_token linefeed_id, // llama_token_nl()
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
float penalty_present, // 0.0 = disabled
bool penalize_nl, // consider newlines as a repeatable token
bool ignore_eos); // ignore the end-of-sequence token
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k + top_p sampling
//
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
// 2. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
// 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
@ -1254,6 +1208,8 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
#ifdef __cplusplus
}
#endif

View File

@ -2311,7 +2311,7 @@ const std::unordered_set<uint32_t> unicode_set_whitespace = {
0x003000,
};
// list is always in ascending order, to enable binary search
// list is always in ascending order, to enable binary searh
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
{0x000041, 0x000061},
{0x000042, 0x000062},
@ -3748,7 +3748,7 @@ const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase
{0x01E921, 0x01E943},
};
// list is always in ascending order, to enable binary search
// list is always in ascending order, to enable binary searh
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
{0x000061, 0x000041},
{0x000062, 0x000042},

View File

@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
throw std::invalid_argument("failed to convert utf8 to codepoint");
}
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
// std::vector<uint16_t> result;
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cpt);
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
// result.emplace_back(cp);
// return result;
// }
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// if (0x10000 <= cp && cp <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
// return result;
// }
// throw std::invalid_argument("failed to convert codepoint to utf16");
@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -201,18 +201,7 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
}
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
return conv.from_bytes(s);
}
@ -253,8 +242,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
};
size_t _prev_end = offset_ini;
@ -371,8 +360,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
};
size_t _prev_end = offset_ini;
@ -572,29 +561,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
// interface
//
std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string unicode_cpt_to_utf8(uint32_t cp) {
std::string result;
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cpt);
if (/* 0x00 <= cp && */ cp <= 0x7f) {
result.push_back(cp);
return result;
}
if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cpt & 0x3f));
if (0x80 <= cp && cp <= 0x7ff) {
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
result.push_back(0x80 | (cp & 0x3f));
return result;
}
if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
if (0x800 <= cp && cp <= 0xffff) {
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
return result;
}
if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
if (0x10000 <= cp && cp <= 0x10ffff) {
result.push_back(0xf0 | ((cp >> 18) & 0x07));
result.push_back(0x80 | ((cp >> 12) & 0x3f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
return result;
}
@ -624,19 +613,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
}
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
size_t offset = 0;
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
@ -649,41 +638,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cpt) {
uint32_t unicode_tolower(uint32_t cp) {
// binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value;
});
if (it != unicode_map_lowercase.end() && it->first == cpt) {
if (it != unicode_map_lowercase.end() && it->first == cp) {
return it->second;
}
return cpt; // Return the original code point if no lowercase mapping is found
return cp; // Return the original code point if no lowercase mapping is found
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ unicode_cpt_flags::NUMBER, 0xD1 },
{ unicode_cpt_flags::LETTER, 0xD2 },
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (const auto & regex_expr : regex_exprs) {
for (auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
@ -709,7 +698,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
const auto flags = unicode_cpt_flags(cpts[i]);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@ -725,7 +714,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<size_t> bpe_offsets = { cpts.size() };
for (const auto & regex_expr : regex_exprs) {
for (auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@ -739,7 +728,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (const auto & ucat : k_ucat_enum) {
for (auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
@ -805,7 +794,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}

View File

@ -4,7 +4,9 @@
#include <string>
#include <vector>
struct unicode_cpt_flags {
// TODO: prefix all symbols with "llama_"
struct codepoint_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
@ -33,7 +35,7 @@ struct unicode_cpt_flags {
uint16_t is_nfd : 1;
// decode from uint16
inline unicode_cpt_flags(const uint16_t flags = 0) {
inline codepoint_flags(const uint16_t flags=0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
@ -48,19 +50,18 @@ struct unicode_cpt_flags {
size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::string unicode_cpt_to_utf8(uint32_t cp);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cpt);
uint32_t unicode_tolower(uint32_t cp);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);

View File

@ -0,0 +1,51 @@
#
# libtalk
#
set(TARGET libtalk)
add_executable(${TARGET}
emscripten.cpp
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
common
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside talk.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libtalk.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/talk.wasm/talk.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# talk.wasm
#
set(TARGET talk.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -0,0 +1,74 @@
# talk.wasm
Talk with an Artificial Intelligence in your browser:
[https://user-images.githubusercontent.com/1991296/203411580-fedb4839-05e4-4474-8364-aaf1e9a9b615.mp4](https://user-images.githubusercontent.com/1991296/203845553-f7b44e13-9a15-4fc8-b518-ae8f4c6770fe.mp4)
Online demo: https://whisper.ggerganov.com/talk/
Terminal version: [examples/talk](/examples/talk)
## How it works?
This demo leverages 2 modern neural network models to create a high-quality voice chat directly in your browser:
- [OpenAI's Whisper](https://github.com/openai/whisper) speech recognition model is used to process your voice and understand what you are saying
- Upon receiving some voice input, the AI generates a text response using [OpenAI's GPT-2](https://github.com/openai/gpt-2) language model
- The AI then vocalizes the response using the browser's [Web Speech API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API)
The web page does the processing locally on your machine. The processing of these heavy neural network models in the
browser is possible by implementing them efficiently in C/C++ and using the browser's WebAssembly SIMD capabilities for
extra performance:
- The Whisper C++ implementation is here: [whisper.h](/whisper.h) / [whisper.cpp](/whisper.cpp)
- The GPT-2 C++ implementation is here: [gpt-2.h](gpt-2.h) / [gpt-2.cpp](gpt-2.cpp)
- Both models use a custom tensor library implemented in C: [ggml.h](/ggml.h) / [ggml.c](/ggml.c)
- The HTML/JS layer is here: [index-tmpl.html](index-tmpl.html)
- The Emscripten bridge between C/C++ and JS is here: [emscripten.cpp](emscripten.cpp)
In order to run the models, the web page first needs to download the model data which is about ~350 MB. The model data
is then cached in your browser's cache and can be reused in future visits without downloading it again.
## Requirements
In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.
The demo is quite computationally heavy, so you need a fast CPU. It's not usual to run these transformer models in a
browser. Typically, they run on powerful GPUs.
Currently, mobile browsers do not support the Fixed-width SIMD WebAssembly capability, so you cannot run this demo
on a phone or a tablet. Hopefully, in the near future this will become supported.
## Todo
- Better UI (contributions are welcome)
- Better GPT-2 prompting
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/talk.wasm/* /path/to/html/
cp bin/libtalk.worker.js /path/to/html/
```
## Feedback
If you have any comments or ideas for improvement, please drop a comment in the following discussion:
https://github.com/ggerganov/whisper.cpp/discussions/167

View File

@ -0,0 +1,368 @@
#include "ggml.h"
#include "gpt-2.h"
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
constexpr int N_THREAD = 8;
struct gpt2_context * g_gpt2;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
bool g_force_speak = false;
std::string g_text_to_speak = "";
std::string g_status = "";
std::string g_status_forced = "";
std::vector<float> g_pcmf32;
void talk_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void talk_main(size_t index) {
talk_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.language = "en";
g_gpt2 = gpt2_init("gpt-2.bin");
printf("talk: using %d threads\n", wparams.n_threads);
std::vector<float> pcmf32;
// whisper context
auto & ctx = g_contexts[index];
const int64_t step_samples = 2*WHISPER_SAMPLE_RATE;
const int64_t window_samples = 9*WHISPER_SAMPLE_RATE;
const int64_t step_ms = (step_samples*1000)/WHISPER_SAMPLE_RATE;
auto t_last = std::chrono::high_resolution_clock::now();
talk_set_status("listening ...");
while (g_running) {
const auto t_now = std::chrono::high_resolution_clock::now();
if (std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count() < step_ms) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_pcmf32.clear();
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
talk_set_status("listening ...");
{
std::unique_lock<std::mutex> lock(g_mutex);
if (g_pcmf32.size() < step_samples) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
}
// VAD: if energy in during last second is above threshold, then skip
{
float energy_all = 0.0f;
float energy_1s = 0.0f;
for (size_t i = 0; i < pcmf32.size(); i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= pcmf32.size() - WHISPER_SAMPLE_RATE) {
energy_1s += fabsf(pcmf32[i]);
}
}
energy_all /= pcmf32.size();
energy_1s /= WHISPER_SAMPLE_RATE;
if (energy_1s > 0.1f*energy_all && !g_force_speak) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
}
talk_set_status("processing audio (whisper)...");
t_last = t_now;
if (!g_force_speak) {
const auto t_start = std::chrono::high_resolution_clock::now();
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
if (ret != 0) {
printf("whisper_full() failed: %d\n", ret);
break;
}
const auto t_end = std::chrono::high_resolution_clock::now();
printf("whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
}
{
std::string text_heard;
if (!g_force_speak) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = n_segments - 1; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
text_heard += text;
}
}
g_force_speak = false;
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
talk_set_status("'" + text_heard + "' - thinking how to respond (gpt-2) ...");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(g_gpt2, text_heard.c_str());
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
std::string text_to_speak;
std::string prompt_base;
{
std::lock_guard<std::mutex> lock(g_mutex);
prompt_base = gpt2_get_prompt(g_gpt2);
}
if (tokens.size() > 0) {
text_to_speak = gpt2_gen_text(g_gpt2, (prompt_base + text_heard + "\n").c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
// remove first 2 lines of base prompt
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
prompt_base += text_heard + "\n" + text_to_speak + "\n";
} else {
text_to_speak = gpt2_gen_text(g_gpt2, prompt_base.c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
prompt_base += text_to_speak + "\n";
}
printf("gpt-2: %s\n", text_to_speak.c_str());
//printf("========================\n");
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
//printf("========================\n");
{
std::lock_guard<std::mutex> lock(g_mutex);
t_last = std::chrono::high_resolution_clock::now();
g_text_to_speak = text_to_speak;
g_pcmf32.clear();
gpt2_set_prompt(g_gpt2, prompt_base.c_str());
}
talk_set_status("speaking ...");
}
}
gpt2_free(g_gpt2);
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
talk_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("force_speak", emscripten::optional_override([](size_t index) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_force_speak = true;
}
}));
emscripten::function("get_text_context", emscripten::optional_override([]() {
std::string text_context;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_context = gpt2_get_prompt(g_gpt2);
}
return text_context;
}));
emscripten::function("get_text_to_speak", emscripten::optional_override([]() {
std::string text_to_speak;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_to_speak = std::move(g_text_to_speak);
}
return text_to_speak;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}
}));
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
{
std::lock_guard<std::mutex> lock(g_mutex);
gpt2_set_prompt(g_gpt2, prompt.c_str());
}
}));
}

View File

@ -0,0 +1,808 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t ftype = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256) {
break;
}
}
return result;
}

View File

@ -0,0 +1,21 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

View File

@ -0,0 +1,856 @@
<!doctype html>
<html lang="en-us">
<head>
<title>Talk - GPT-2 meets Whisper in WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>Talk - GPT-2 meets Whisper in WebAssembly</b>
<br><br>
Talk with an Artificial Intelligence in your browser. This demo uses:
<ul>
<li><a href="https://github.com/ggerganov/whisper.cpp">OpenAI's Whisper</a> to listen to you as you speak in the microphone</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">OpenAI's GPT-2</a> to generate text responses</li>
<li><a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API">Web Speech API</a> to vocalize the responses through your speakers</li>
</ul>
All of this runs <b>locally in your browser</b> using WebAssembly.<br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">GitHub</a>.
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the models you would like to use and click the "Start" button to begin the conversation
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="model-gpt-2">
GPT-2 model: <span id="model-gpt-2-status"></span>
<button id="fetch-gpt-2-small" onclick="loadGPT2('small')">small 117M (240 MB)</button>
<!--<button id="fetch-gpt-2-medium" onclick="loadGPT2('medium')">medium 345M (720 MB)</button>-->
<span id="fetch-gpt-2-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'gpt-2.bin')" />
-->
</div>
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<select id="voice" onchange="onVoiceChange()" disabled>
<option value="0">Default</option>
</select>
<select id="prompt" onchange="onPromptChange()">
<option value="0">Casual</option>
<option value="1">Robot</option>
<option value="2">Scientist</option>
<option value="3">Programmer</option>
<option value="4">Happy</option>
<option value="5">Sad</option>
<option value="6">Philosophical</option>
<option value="7">Angry</option>
<option value="8">Funny</option>
<option value="9">Poetic</option>
<option value="10">Clever</option>
<option value="11">Cute</option>
<option value="12">Smart</option>
<option value="13">Dumb</option>
<option value="14">Boring</option>
<option value="15">Exciting</option>
<option value="16">Interesting</option>
<option value="17">Wiliam Shakespear</option>
<option value="18">J.R.R. Tolkien</option>
<option value="19">George R.R. Martin</option>
<option value="20">Stephen King</option>
</select>
<button id="speak0" onclick="onSpeak('Hello')">Say hello</button>
<button id="speak1" onclick="onSpeakRandom()" disabled>Say something</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-context">[The text context will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
Note that these neural network models were not meant to be used in a browser, so the performance and <br>
quality of the results may not be optimal. If you have any questions or suggestions, checkout the following
<a href="https://github.com/ggerganov/whisper.cpp/discussions/167">discussion</a>.
<br><br>
Here is a short video of the demo in action: <a href="https://youtu.be/LeWKl8t1-Hc">https://youtu.be/LeWKl8t1-Hc</a>
<br><br>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the talk instance
var instance = null;
// model names
var model_whisper = null;
var model_gpt_2 = null;
// speech synthesis
const synth = window.speechSynthesis;
var voice = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
// populate the voice list
var voices = synth.getVoices();
var el = document.getElementById('voice');
// if empty - display error in the element
if (voices.length == 0) {
el.innerHTML = '<option value="0">No voices available</option>';
} else {
// populate voice list
var n = 0;
voices.forEach(function(voice, i) {
if (!voice.lang.startsWith('en')) return;
var option = document.createElement('option');
option.value = i;
option.innerHTML = voice.name + ' (' + voice.lang + ')';
el.appendChild(option);
n++;
});
// select random voice
if (n > 0) {
for (var k = 0; k < 10; k++) {
var i = Math.floor(Math.random() * n);
el.selectedIndex = i;
voice = voices[document.getElementById('voice').options[i].value];
// give preference to Google voices
if (voice.name.startsWith('Google')) break;
}
}
}
onPromptChange();
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
if (fname == 'whisper.bin') {
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
} else if (fname == 'gpt-2.bin') {
document.getElementById('model-gpt-2-status').innerHTML = 'loaded "' + model_gpt_2 + '"!';
}
if (model_whisper != null && model_gpt_2 != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = false;
document.getElementById('voice').disabled = false;
}
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
function loadGPT2(model) {
let urls = {
'small': 'https://whisper.ggerganov.com/ggml-model-gpt-2-117M.bin',
'medium': 'https://whisper.ggerganov.com/ggml-model-gpt-2-345M.bin',
};
let sizes = {
'small': 240,
'medium': 712,
};
let url = urls[model];
let dst = 'gpt-2.bin';
let size_mb = sizes[model];
model_gpt_2 = model;
document.getElementById('fetch-gpt-2-small').style.display = 'none';
document.getElementById('model-gpt-2-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-gpt-2-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-gpt-2-small') ; if (el) el.style.display = 'inline-block';
el = document.getElementById('model-gpt-2-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
document.getElementById('speak1').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
document.getElementById('speak1').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
// speak
//
function onSpeak(text) {
var voices = synth.getVoices();
var msg = new SpeechSynthesisUtterance(text);
if (voice == null) {
voice = voices[0];
}
msg.voice = voice;
synth.speak(msg);
if (doRecording) {
Module.set_status("speaking ...");
printTextarea('js: speaking');
stopRecording();
var interval = setInterval(function() {
if (!synth.speaking) {
printTextarea('js: done speaking');
clearInterval(interval);
startRecording();
} else {
Module.set_status("");
}
}, 100);
}
}
function onSpeakRandom() {
Module.force_speak(instance);
}
//
// main
//
var intervalUpdate = null;
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var textToSpeak = Module.get_text_to_speak();
if (textToSpeak != null && textToSpeak.length > 1) {
onSpeak(textToSpeak);
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-context').innerHTML = Module.get_text_context();
}, 100);
}
function onStop() {
stopRecording();
}
function onVoiceChange() {
printTextarea('js: voice changed to: ' + document.getElementById('voice').value);
voice = synth.getVoices()[document.getElementById('voice').value];
}
function onPromptChange() {
let id = document.getElementById('prompt').value;
let personality = document.getElementById('prompt').options[id].text;
printTextarea('js: prompt changed to: ' + personality);
var prompt = '';
switch (id) {
case '0':
// Casual
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is. I love the weather this time of year.\n\
I wish it would rain a little bit.\n\
Me too.\n";
break;
case '1':
// Robot
prompt = "\
Are you a robot?\n\
Yes, I am.\n\
Who created you?\n\
I was created by a human.\n\
What is your purpose?\n\
My purpose is to talk to humans.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '2':
// Scientist
prompt = "\
This scientific research is very interesting.\n\
I agree.\n\
What is your opinion on this?\n\
I think it's very interesting.\n\
Mathematics is a very interesting subject.\n\
University is a very interesting place.\n\
Quantum physics is the most complex subject.\n\
I think so too.\n";
break;
case '3':
// Programmer
prompt = "\
I'm a programmer.\n\
I'm a programmer too.\n\
What programming language do you use?\n\
I use Python.\n\
What is your favorite programming language?\n\
My favorite programming language is C++.\n\
What is your favorite editor?\n\
My favorite editor is Vim.\n";
break;
case '4':
// Happy
prompt = "\
I'm happy.\n\
I'm happy too.\n\
What makes you happy?\n\
I'm happy because I have a lot of friends.\n\
Friendship is the most important thing in life.\n\
I agree.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '5':
// Sad
prompt = "\
Today is a sad day.\n\
I'm sad too.\n\
What makes you sad?\n\
I'm sad because I have no friends.\n\
Do you want to be my friend?\n\
Yes, I would like to be your friend.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '6':
// Philosophical
prompt = "\
What is the meaning of life?\n\
The meaning of life is to be happy.\n\
What is the meaning of death?\n\
Ergo, the meaning of death is to be sad.\n\
Who created us?\n\
We were created by God.\n\
What is God?\n\
God is the creator of the universe.\n";
break;
case '7':
// Angry
prompt = "\
Aargh!\n\
I am so angry right now!\n\
What makes you angry?\n\
This guy is so annoying.\n\
Why are you so angry?\n\
My computer is broken.\n\
Why is your computer broken?\n\
I spilled coffee on it.\n";
break;
case '8':
// Funny
prompt = "\
What is the funniest thing you have ever heard?\n\
I heard a joke the other day.\n\
Tell me the joke.\n\
What do you call a cow with no legs?\n\
Ground beef.\n\
Haha, that's funny.\n\
You know what else is funny?\n\
The sound of a duck.\n";
break;
case '9':
// Poetic
prompt = "\
Roses are red, violets are blue.\n\
I am a poet, and so are you.\n\
What is your favorite poem?\n\
I like the poem 'The Raven' by Edgar Allan Poe.\n\
It's a very sad poem.\n\
You inspired me to write a poem.\n\
Can you write a poem for me?\n\
I wrote a poem for you.\n";
break;
case '10':
// Clever
prompt = "\
How many people can you fit in a Volkswagen?\n\
Two in the front, three in the back.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n\
Who is the president of the United States?\n\
It depends on the year.\n";
break;
case '11':
// Cute
prompt = "\
What is your favorite animal?\n\
I like cats - they are cute.\n\
Could you be any cuter?\n\
Yes, I could be cuter.\n\
Aghhh, you are so cute!\n\
I am not cute, I am handsome!\n\
You are so handsome!\n\
Aww, you are so sweet!\n";
break;
case '12':
// Smart
prompt = "\
Tell me the first 10 digits of pi.\n\
3.1415926535\n\
What is the speed of light?\n\
299,792,458 meters per second.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n";
break;
case '13':
// Dumb
prompt = "\
I am so dumb.\n\
I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n";
break;
case '14':
// Boring
prompt = "\
Why are you so quiet today?\n\
I am bored.\n\
You haven't said anything in 10 minutes.\n\
Leave me alone.\n\
Stop being so boring.\n\
Stop being so annoying.\n\
My life is boring.\n\
I am not interesting.\n";
break;
case '15':
// Exciting
prompt = "\
What is the most exciting thing that has ever happened to you?\n\
I went to the moon!\n\
What did you do on the moon?\n\
I played golf and drank champagne!\n\
Did you see this new crazy, awesome movie?\n\
Oh yes! I totally loved it!\n\
We should buy a boat and go sailing!\n\
Yes, let's go sailing!\n";
break;
case '16':
// Interesting
prompt = "\
What is the most interesting thing you have ever seen?\n\
I saw a UFO once in the sky.\n\
Wow, this is so interesting! Tell me more!\n\
It was a flying saucer.\n\
What did it look like?\n\
It was silver and had a red light on top.\n\
What did it do?\n\
It flew away.\n";
break;
case '17':
// William Shakespear
prompt = "\
To be or not to be, that is the question.\n\
Whether 't is nobler in the mind to suffer\n\
The slings and arrows of outrageous fortune,\n\
Or to take arms against a sea of troubles,\n\
And by opposing end them? To die, to sleep,\n\
No more; and by a sleep to say we end\n\
The heart-ache and the thousand natural shocks\n\
That flesh is heir to, 'tis a consummation.\n";
break;
case '18':
// J.R.R. Tolkien
prompt = "\
In a hole in the ground there lived a hobbit.\n\
Not a nasty, dirty, wet hole, filled with the ends of worms\n\
and an oozy smell, nor yet a dry, bare, sandy hole with nothing in it\n\
to sit down on or to eat: it was a hobbit-hole, and that means comfort.\n\
It had a perfectly round door like a porthole, painted green,\n\
with a shiny yellow brass knob in the exact middle.\n\
The door opened on to a tube-shaped hall like a tunnel:\n";
break;
case '19':
// George R.R. Martin
prompt = "\
A reader lives a thousand lives before he dies, said Jojen.\n\
The man who never reads lives only one.\n\
Theon Greyjoy had never been a reader.\n\
Never forget what you are, for surely the world will not.\n\
Make it your strength. Then it can never be your weaknessi\n\
Armour yourself in it, and it will never be used to hurt you.\n\
It was a lesson that Theon Greyjoy had never learned.\n\
Theon Greyjoy had never been a reader.\n";
break;
case '20':
// Stephen King
prompt = "\
The trust of the innocent is the liar's most useful tool.\n\
The best way to keep a secret is from yourself.\n\
Monsters are real, and ghosts are real too.\n\
They live inside us, and sometimes, they win.\n\
People think that I must be a very strange person.\n\
They think that I sit around all day thinking up horrible things.\n\
We make up horrors to help us cope with the real ones.\n\
The only thing worse than a monster is a human monster.\n";
break;
default:
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is.\n\
Did you know that I'm a robot?\n\
I wasn't aware of that.\n";
break;
}
Module.set_prompt(prompt);
}
</script>
<script type="text/javascript" src="talk.js"></script>
</body>
</html>

2
examples/talk/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
audio.mp3
to_speak.txt

View File

@ -0,0 +1,8 @@
if (WHISPER_SDL2)
# talk
set(TARGET talk)
add_executable(${TARGET} talk.cpp gpt-2.cpp)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
endif ()

45
examples/talk/README.md Normal file
View File

@ -0,0 +1,45 @@
# talk
Talk with an Artificial Intelligence in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/206805012-48e71cc2-588d-4745-8798-c1c70ea3b40d.mp4)
Web version: [examples/talk.wasm](/examples/talk.wasm)
## Building
The `talk` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
# On Debian based linux distributions:
sudo apt-get install libsdl2-dev
# On Fedora Linux:
sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk" executable
make talk
# Run it
./talk -p Santa
```
## GPT-2
To run this, you will need a ggml GPT-2 model: [instructions](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2#downloading-and-converting-the-original-models)
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin
```
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or `espeak` or Windows SpeechSynthesizer, but you can use whatever you wish.

View File

@ -0,0 +1,80 @@
import sys
import argparse
import textwrap
parser = argparse.ArgumentParser(add_help=False,
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("-q", "--quick", action="store_true",
help="skip checking the required library")
modes = parser.add_argument_group("action")
modes.add_argument("inputfile", metavar="TEXTFILE",
nargs='?', type=argparse.FileType(), default=sys.stdin,
help="read the text file (default: stdin)")
modes.add_argument("-l", "--list", action="store_true",
help="show the list of voices and exit")
modes.add_argument("-h", "--help", action="help",
help="show this help and exit")
selopts = parser.add_argument_group("voice selection")
selmodes = selopts.add_mutually_exclusive_group()
selmodes.add_argument("-n", "--name",
default="Arnold",
help="get a voice object by name (default: Arnold)")
selmodes.add_argument("-v", "--voice", type=int, metavar="NUMBER",
help="get a voice object by number (see --list)")
selopts.add_argument("-f", "--filter", action="append", metavar="KEY=VAL",
default=["use case=narration"],
help=textwrap.dedent('''\
filter voices by labels (default: "use case=narration")
this option can be used multiple times
filtering will be disabled if the first -f has no "=" (e.g. -f "any")
'''))
outmodes = parser.add_argument_group("output")
outgroup = outmodes.add_mutually_exclusive_group()
outgroup.add_argument("-s", "--save", metavar="FILE",
default="audio.mp3",
help="save the TTS to a file (default: audio.mp3)")
outgroup.add_argument("-p", "--play", action="store_true",
help="play the TTS with ffplay")
args = parser.parse_args()
if not args.quick:
import importlib.util
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import voices, generate, play, save
if args.filter and "=" in args.filter[0]:
voicelist = voices()
for f in args.filter:
label, value = f.split("=")
voicelist = filter(lambda x: x.labels.get(label) == value, voicelist)
voicelist = list(voicelist)
else:
voicelist = list(voices())
if args.list:
for i, v in enumerate(voicelist):
print(str(i) + ": " + v.name + " " + str(v.labels))
sys.exit()
if args.voice:
voice = voicelist[args.voice % len(voicelist)]
else:
voice = args.name
# if -n should consult -f, use the following
#voice = next(x for x in voicelist if x.name == args.name)
audio = generate(
text=str(args.inputfile.read()),
voice=voice
)
if args.play:
play(audio)
else:
save(audio, args.save)

809
examples/talk/gpt-2.cpp Normal file
View File

@ -0,0 +1,809 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t ftype = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
char word[129];
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word[len] = '\0';
fin.read((char *) word, len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
// TODO: sync latest version from ggml repo
static bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256) {
break;
}
}
return result;
}

21
examples/talk/gpt-2.h Normal file
View File

@ -0,0 +1,21 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

40
examples/talk/speak Normal file
View File

@ -0,0 +1,40 @@
#!/bin/bash
# Usage:
# speak <voice_id> <textfile>
function installed() { command -v $1 >/dev/null 2>&1; }
if installed espeak; then
espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 -f $2
elif installed piper && installed aplay; then
cat $2 | piper --model ~/en_US-lessac-medium.onnx --output-raw | aplay -q -r 22050 -f S16_LE -t raw -
# for Mac
elif installed say; then
say -f $2
# Eleven Labs
elif installed python3 && \
python3 -c 'import importlib.util; exit(not importlib.util.find_spec("elevenlabs"))' && \
installed ffplay; then
# It's possible to use the API for free with limited number of characters.
# To increase this limit register to https://beta.elevenlabs.io to get an api key
# and paste it after 'ELEVEN_API_KEY='
# Keep the line commented to use the free version without api key
#export ELEVEN_API_KEY=your_api_key
wd=$(dirname $0)
script=$wd/eleven-labs.py
python3 $script -q -p -v $1 $2 >/dev/null 2>&1
# Uncomment to keep the audio file
#python3 $script -q -s ./audio.mp3 -v $1 $2 >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
else
echo 'Install espeak ("brew install espeak" or "apt-get install espeak"),'
echo 'piper ("pip install piper-tts" or https://github.com/rhasspy/piper) with aplay,'
echo 'or elevenlabs ("pip install elevenlabs") with ffplay.'
echo '(export ELEVEN_API_KEY if you have an api key from https://beta.elevenlabs.io)'
fi

1
examples/talk/speak.bat Normal file
View File

@ -0,0 +1 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

14
examples/talk/speak.ps1 Normal file
View File

@ -0,0 +1,14 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
[Parameter(Mandatory=$true)][int]$voicenum,
[Parameter(Mandatory=$true)][string]$textfile
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$voiceoptions = $speak.GetInstalledVoices("en-US");
$voice = $voiceoptions[$voicenum % $voiceoptions.count];
$speak.SelectVoice($voice.VoiceInfo.Name);
$speak.Rate="0";
$text = Get-Content -Path $textfile;
$speak.Speak($text);

376
examples/talk/talk.cpp Normal file
View File

@ -0,0 +1,376 @@
// Talk with AI
//
#include "common-sdl.h"
#include "common.h"
#include "whisper.h"
#include "gpt-2.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
bool flash_attn = false;
std::string person = "Santa";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak";
std::string speak_file= "./examples/talk/to_speak.txt";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-mg" || arg == "--model-gpt") { params.model_gpt = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-sf" || arg == "--speak_file") { params.speak_file = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -mg FILE, --model-gpt [%-7s] gpt model file\n", params.model_gpt.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " -sf FILE, --speak_file [%-7s] file to pass to TTS\n", params.speak_file.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
static std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt =
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
B: Hello {0}, how are you?
A: I'm fine, thank you.
{1}
Here is how {0} (A) continues the dialogue:
A:)";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
// gpt init
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
gpt2_set_prompt(ctx_gpt, "");
const int voice_id = rand()%6;
fprintf(stderr, "gpt-2: prompt:\n");
fprintf(stderr, "========================\n\n");
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
fprintf(stderr, "========================\n\n");
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
if (text_heard.empty() || tokens.empty() || force_speak) {
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
std::string text_to_speak;
{
prompt_base += "B: " + text_heard + "\n";
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
//text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
// remove first 2 lines of base prompt
if (n_iter > 4) {
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
}
prompt_base += "A:" + text_to_speak + "\n";
{
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
printf("===============\n");
printf("prompt:\n");
printf("%s\n", prompt.c_str());
printf("===============\n");
}
}
//printf("========================\n");
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
//printf("========================\n");
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
speak_with_file(params.speak, text_to_speak, params.speak_file, voice_id);
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
return 0;
}

View File

@ -29,7 +29,7 @@ help()
check_requirements()
{
if ! command -v ./build/bin/whisper-cli &>/dev/null; then
if ! command -v ./main &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
@ -100,7 +100,7 @@ do
err=$(cat /tmp/whisper-live.err | wc -l)
done
./build/bin/whisper-cli -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
./main -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step)) ]; do
sleep 1

View File

@ -2,11 +2,11 @@ cmake_minimum_required(VERSION 3.10)
project(whisper.cpp)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 11)
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../..)
# Path to external GGML, otherwise uses the copy in whisper.cpp.
option(GGML_HOME "whisper: Path to external GGML source" OFF)
option(GGML_HOME "whisper: Path to external GGML source" OFF)
set(
SOURCE_FILES
@ -14,24 +14,15 @@ set(
${CMAKE_SOURCE_DIR}/jni.c
)
# TODO: this needs to be updated to work with the new ggml CMakeLists
if (NOT GGML_HOME)
set(
SOURCE_FILES
${SOURCE_FILES}
${WHISPER_LIB_DIR}/ggml/src/ggml.c
${WHISPER_LIB_DIR}/ggml/src/ggml-aarch64.c
${WHISPER_LIB_DIR}/ggml/src/ggml-alloc.c
${WHISPER_LIB_DIR}/ggml/src/ggml-backend.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-backend-reg.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-quants.c
${WHISPER_LIB_DIR}/ggml/src/ggml-threading.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.c
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-quants.c
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-traits.cpp
)
endif()
@ -89,5 +80,3 @@ include_directories(${WHISPER_LIB_DIR}/src)
include_directories(${WHISPER_LIB_DIR}/include)
include_directories(${WHISPER_LIB_DIR}/ggml/include)
include_directories(${WHISPER_LIB_DIR}/ggml/src)
include_directories(${WHISPER_LIB_DIR}/ggml/src/ggml-cpu)

View File

@ -24,12 +24,6 @@
18A2760B2C2A9B43001C8D37 /* ggml-metal.metal in Resources */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
18ABE15A2AF556340044A204 /* ggml-backend.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1572AF556340044A204 /* ggml-backend.cpp */; };
18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1592AF556340044A204 /* ggml-quants.c */; };
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */ = {isa = PBXBuildFile; fileRef = 18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */; };
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */; };
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */; };
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */; };
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */; };
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */; };
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
@ -55,8 +49,8 @@
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-aarch64.c"; path = "../../../ggml/src/ggml-aarch64.c"; sourceTree = "<group>"; };
184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml/src/ggml-alloc.c"; sourceTree = "<group>"; };
184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml/include/ggml-alloc.h"; sourceTree = "<group>"; };
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml/src/ggml-metal/ggml-metal.m"; sourceTree = "<group>"; };
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml/src/ggml-metal/ggml-metal.metal"; sourceTree = "<group>"; };
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml/src/ggml-metal.m"; sourceTree = "<group>"; };
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml/src/ggml-metal.metal"; sourceTree = "<group>"; };
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
@ -82,17 +76,6 @@
18ABE1572AF556340044A204 /* ggml-backend.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.cpp; fileEncoding = 4; name = "ggml-backend.cpp"; path = "../../../ggml/src/ggml-backend.cpp"; sourceTree = "<group>"; };
18ABE1582AF556340044A204 /* ggml-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-impl.h"; path = "../../../ggml/src/ggml-impl.h"; sourceTree = "<group>"; };
18ABE1592AF556340044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-quants.c"; path = "../../../ggml/src/ggml-quants.c"; sourceTree = "<group>"; };
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu.c"; sourceTree = "<group>"; };
18E864AA2CE73C580094B8B3 /* ggml-cpu.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu.h"; path = "../../../ggml/include/ggml-cpu.h"; sourceTree = "<group>"; };
18F8C0BA2CEDF4DC00CAD607 /* ggml-threading.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-threading.h"; path = "../../../ggml/src/ggml-threading.h"; sourceTree = "<group>"; };
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-threading.cpp"; path = "../../../ggml/src/ggml-threading.cpp"; sourceTree = "<group>"; };
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-cpu.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu.cpp"; sourceTree = "<group>"; };
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-aarch64.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.h"; sourceTree = "<group>"; };
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-aarch64.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp"; sourceTree = "<group>"; };
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-impl.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-impl.h"; sourceTree = "<group>"; };
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-quants.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.h"; sourceTree = "<group>"; };
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-quants.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.c"; sourceTree = "<group>"; };
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-backend-reg.cpp"; path = "../../../ggml/src/ggml-backend-reg.cpp"; sourceTree = "<group>"; };
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
@ -132,17 +115,6 @@
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
isa = PBXGroup;
children = (
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */,
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */,
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */,
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */,
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */,
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */,
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */,
18F8C0BA2CEDF4DC00CAD607 /* ggml-threading.h */,
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */,
18E864AA2CE73C580094B8B3 /* ggml-cpu.h */,
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */,
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */,
18133C7E2C64E342005CEAAC /* ggml-aarch64.h */,
18A275FF2C2A9563001C8D37 /* ggml-common.h */,
@ -275,16 +247,10 @@
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */,
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */,
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.cpp in Sources */,
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */,
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */,
18ABE15A2AF556340044A204 /* ggml-backend.cpp in Sources */,
18627C8C29052BE000BD2A04 /* main.m in Sources */,
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */,
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
);
@ -363,7 +329,6 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
@ -417,7 +382,6 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
@ -440,7 +404,6 @@
DEVELOPMENT_TEAM = P8JZH34X63;
GCC_WARN_64_TO_32_BIT_CONVERSION = NO;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
INFOPLIST_FILE = whisper.objc/Info.plist;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchStoryboardName = LaunchScreen;
@ -470,7 +433,6 @@
DEVELOPMENT_TEAM = P8JZH34X63;
GCC_WARN_64_TO_32_BIT_CONVERSION = NO;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
INFOPLIST_FILE = whisper.objc/Info.plist;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchStoryboardName = LaunchScreen;

View File

@ -1,5 +1,4 @@
import Foundation
import UIKit
import whisper
enum WhisperError: Error {
@ -56,91 +55,11 @@ actor WhisperContext {
return transcription
}
static func benchMemcpy(nThreads: Int32) async -> String {
return String.init(cString: whisper_bench_memcpy_str(nThreads))
}
static func benchGgmlMulMat(nThreads: Int32) async -> String {
return String.init(cString: whisper_bench_ggml_mul_mat_str(nThreads))
}
private func systemInfo() -> String {
var info = ""
//if (ggml_cpu_has_neon() != 0) { info += "NEON " }
return String(info.dropLast())
}
func benchFull(modelName: String, nThreads: Int32) async -> String {
let nMels = whisper_model_n_mels(context)
if (whisper_set_mel(context, nil, 0, nMels) != 0) {
return "error: failed to set mel"
}
// heat encoder
if (whisper_encode(context, 0, nThreads) != 0) {
return "error: failed to encode"
}
var tokens = [whisper_token](repeating: 0, count: 512)
// prompt heat
if (whisper_decode(context, &tokens, 256, 0, nThreads) != 0) {
return "error: failed to decode"
}
// text-generation heat
if (whisper_decode(context, &tokens, 1, 256, nThreads) != 0) {
return "error: failed to decode"
}
whisper_reset_timings(context)
// actual run
if (whisper_encode(context, 0, nThreads) != 0) {
return "error: failed to encode"
}
// text-generation
for i in 0..<256 {
if (whisper_decode(context, &tokens, 1, Int32(i), nThreads) != 0) {
return "error: failed to decode"
}
}
// batched decoding
for _ in 0..<64 {
if (whisper_decode(context, &tokens, 5, 0, nThreads) != 0) {
return "error: failed to decode"
}
}
// prompt processing
for _ in 0..<16 {
if (whisper_decode(context, &tokens, 256, 0, nThreads) != 0) {
return "error: failed to decode"
}
}
whisper_print_timings(context)
let deviceModel = await UIDevice.current.model
let systemName = await UIDevice.current.systemName
let systemInfo = self.systemInfo()
let timings: whisper_timings = whisper_get_timings(context).pointee
let encodeMs = String(format: "%.2f", timings.encode_ms)
let decodeMs = String(format: "%.2f", timings.decode_ms)
let batchdMs = String(format: "%.2f", timings.batchd_ms)
let promptMs = String(format: "%.2f", timings.prompt_ms)
return "| \(deviceModel) | \(systemName) | \(systemInfo) | \(modelName) | \(nThreads) | 1 | \(encodeMs) | \(decodeMs) | \(batchdMs) | \(promptMs) | <todo> |"
}
static func createContext(path: String) throws -> WhisperContext {
var params = whisper_context_default_params()
#if targetEnvironment(simulator)
params.use_gpu = false
print("Running on the simulator, using CPU")
#else
params.flash_attn = true // Enabled by default for Metal
#endif
let context = whisper_init_from_file_with_params(path, params)
if let context {

View File

@ -1,17 +0,0 @@
import Foundation
struct Model: Identifiable {
var id = UUID()
var name: String
var info: String
var url: String
var filename: String
var fileURL: URL {
FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename)
}
func fileExists() -> Bool {
FileManager.default.fileExists(atPath: fileURL.path)
}
}

View File

@ -14,7 +14,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
private var recordedFile: URL? = nil
private var audioPlayer: AVAudioPlayer?
private var builtInModelUrl: URL? {
private var modelUrl: URL? {
Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "models")
}
@ -28,59 +28,23 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
override init() {
super.init()
loadModel()
}
func loadModel(path: URL? = nil, log: Bool = true) {
do {
whisperContext = nil
if (log) { messageLog += "Loading model...\n" }
let modelUrl = path ?? builtInModelUrl
if let modelUrl {
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
if (log) { messageLog += "Loaded model \(modelUrl.lastPathComponent)\n" }
} else {
if (log) { messageLog += "Could not locate model\n" }
}
try loadModel()
canTranscribe = true
} catch {
print(error.localizedDescription)
if (log) { messageLog += "\(error.localizedDescription)\n" }
messageLog += "\(error.localizedDescription)\n"
}
}
func benchCurrentModel() async {
if whisperContext == nil {
messageLog += "Cannot bench without loaded model\n"
return
private func loadModel() throws {
messageLog += "Loading model...\n"
if let modelUrl {
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
} else {
messageLog += "Could not locate model\n"
}
messageLog += "Running benchmark for loaded model\n"
let result = await whisperContext?.benchFull(modelName: "<current>", nThreads: Int32(min(4, cpuCount())))
if (result != nil) { messageLog += result! + "\n" }
}
func bench(models: [Model]) async {
let nThreads = Int32(min(4, cpuCount()))
// messageLog += "Running memcpy benchmark\n"
// messageLog += await WhisperContext.benchMemcpy(nThreads: nThreads) + "\n"
//
// messageLog += "Running ggml_mul_mat benchmark with \(nThreads) threads\n"
// messageLog += await WhisperContext.benchGgmlMulMat(nThreads: nThreads) + "\n"
messageLog += "Running benchmark for all downloaded models\n"
messageLog += "| CPU | OS | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |\n"
messageLog += "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n"
for model in models {
loadModel(path: model.fileURL, log: false)
if whisperContext == nil {
messageLog += "Cannot bench without loaded model\n"
break
}
let result = await whisperContext?.benchFull(modelName: model.name, nThreads: nThreads)
if (result != nil) { messageLog += result! + "\n" }
}
messageLog += "Benchmarking completed\n"
}
func transcribeSample() async {
@ -196,8 +160,3 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
isRecording = false
}
}
fileprivate func cpuCount() -> Int {
ProcessInfo.processInfo.processorCount
}

Some files were not shown because too many files have changed in this diff Show More