Compare commits

..

2 Commits

Author SHA1 Message Date
fa9621e5e9 mtl : update Makefile to support Metal 2022-11-12 08:32:03 +02:00
b5d3521626 mtl : matrix multiplication support
Seems to be only marginally faster compared to pure AMX
2022-11-09 18:29:38 +02:00
182 changed files with 3184 additions and 19947 deletions

View File

@ -1,22 +0,0 @@
name: Bindings Tests
on:
push:
paths:
- bindings/go/**
- whisper.h
pull_request:
paths:
- bindings/go/**
- whisper.h
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: '^1.19'
- uses: actions/checkout@v1
- run: |
cd bindings/go
make test

View File

@ -1,267 +1,115 @@
name: CI
on: [push, pull_request]
on: [push]
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v1
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install libsdl2-dev
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install libsdl2-dev
- name: Build
run: |
make
make stream
- name: Build
run: |
make
make stream
macOS-latest:
runs-on: macOS-latest
macOS-latest:
runs-on: macOS-latest
steps:
- name: Clone
uses: actions/checkout@v1
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
brew update
brew install sdl2
- name: Dependencies
run: |
brew update
brew install sdl2
- name: Build
run: |
make
make stream
- name: Build
run: |
make
make stream
ubuntu-latest-gcc:
runs-on: ubuntu-latest
ubuntu-latest-gcc:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Debug, Release]
strategy:
matrix:
build: [Debug, Release]
steps:
- name: Clone
uses: actions/checkout@v1
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
- name: Build
run: |
make
ctest -L gh --output-on-failure
- name: Build
run: |
make
ctest -L gh --output-on-failure
ubuntu-latest-clang:
runs-on: ubuntu-latest
ubuntu-latest-clang:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Debug, Release]
strategy:
matrix:
build: [Debug, Release]
steps:
- name: Clone
uses: actions/checkout@v1
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
- name: Build
run: |
make
ctest -L gh --output-on-failure
- name: Build
run: |
make
ctest -L gh --output-on-failure
ubuntu-latest-gcc-sanitized:
runs-on: ubuntu-latest
ubuntu-latest-gcc-sanitized:
runs-on: ubuntu-latest
strategy:
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
strategy:
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
steps:
- name: Clone
uses: actions/checkout@v1
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
- name: Configure
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
- name: Configure
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
- name: Build
run: |
make
ctest -L gh --output-on-failure
windows:
runs-on: windows-latest
strategy:
matrix:
build: [Release]
arch: [Win32, x64]
sdl2: [ON]
include:
- arch: Win32
s2arc: x86
- arch: x64
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build
run: |
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
windows-blas:
runs-on: windows-latest
strategy:
matrix:
build: [Release]
arch: [Win32, x64]
blas: [ON]
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
s2arc: x86
- arch: x64
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
- name: Fetch OpenBLAS
if: matrix.blas == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
7z x blas.zip -oblas -y
copy blas/include/cblas.h .
copy blas/include/openblas_config.h .
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build
run: |
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy libopenblas.dll
if: matrix.blas == 'ON'
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
- name: Build
run: |
make
ctest -L gh --output-on-failure

17
.gitignore vendored
View File

@ -1,5 +1,4 @@
*.o
*.a
.cache/
.vs/
.vscode/
@ -9,25 +8,15 @@ build/
build-em/
build-debug/
build-release/
build-static/
build-sanitize-addr/
build-sanitize-thread/
/main
/stream
/command
/talk
/bench
arm_neon.h
main
stream
bench
sync.sh
libwhisper.a
libwhisper.so
compile_commands.json
examples/arm_neon.h
examples/whisper.objc/whisper.objc.xcodeproj/xcshareddata
examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
extra/bench-gg.txt

View File

@ -1,22 +1,19 @@
cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.0.0)
project(whisper.cpp VERSION 1.2.0)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
include(GitVars)
include(BuildTypes)
include(cmake/GitVars.cmake)
include(cmake/BuildTypes.cmake)
# configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
configure_file(${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl ${CMAKE_SOURCE_DIR}/bindings/ios/Makefile @ONLY)
endif()
configure_file(${CMAKE_SOURCE_DIR}/bindings/javascript/package-tmpl.json ${CMAKE_SOURCE_DIR}/bindings/javascript/package.json @ONLY)
else()
set(WHISPER_STANDALONE OFF)
endif()
@ -51,9 +48,6 @@ option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
@ -84,6 +78,9 @@ endif()
# dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 20)
find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework
@ -97,12 +94,21 @@ if (APPLE AND NOT WHISPER_NO_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK})
endif()
if (WHISPER_SUPPORT_OPENBLAS)
find_library(OPENBLAS_LIB
NAMES openblas libopenblas
)
find_library(OPENBLAS_LIB openblas)
if (OPENBLAS_LIB)
message(STATUS "OpenBLAS found")
@ -130,13 +136,6 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \
-Wstrict-prototypes \
-Wpointer-arith \
-Wno-unused-function \
")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \
-Wextra \
-Wpedantic \
-Wcast-qual \
")
else()
# todo : msvc
@ -155,24 +154,15 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
else()
message(STATUS "x86 detected")
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /arch:AVX2")
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
# we require support for WASM SIMD 128-bit
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
if(NOT WHISPER_NO_FMA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
endif()
endif()
endif()
@ -188,14 +178,11 @@ endif()
set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
whisper.h
ggml-mtl.m
whisper.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
@ -218,10 +205,6 @@ if (BUILD_SHARED_LIBS)
)
endif()
if (EMSCRIPTEN)
set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-msimd128")
endif()
target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS}
)
@ -229,7 +212,6 @@ target_compile_definitions(${TARGET} PUBLIC
install(TARGETS ${TARGET}
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
)
#
@ -242,11 +224,13 @@ add_subdirectory(bindings)
# programs, examples and tests
#
if (WHISPER_BUILD_TESTS)
enable_testing()
add_subdirectory(tests)
endif ()
if (WHISPER_STANDALONE)
if (WHISPER_BUILD_TESTS)
enable_testing()
add_subdirectory(tests)
endif ()
if (WHISPER_BUILD_EXAMPLES)
add_subdirectory(examples)
endif()
if (WHISPER_BUILD_EXAMPLES)
add_subdirectory(examples)
endif()
endif ()

146
Makefile
View File

@ -10,9 +10,6 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin)
@ -30,8 +27,8 @@ endif
# Compile flags
#
CFLAGS = -I. -O3 -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
CFLAGS = -I. -O3 -std=c11
CXXFLAGS = -I. -I./examples -O3 -std=c++11
LDFLAGS =
# OS specific
@ -48,98 +45,23 @@ ifeq ($(UNAME_S),FreeBSD)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),Haiku)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_S),Darwin)
CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
ifneq (,$(findstring AVX2,$(AVX2_M)))
CFLAGS += -mavx2
endif
else ifeq ($(UNAME_S),Linux)
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell grep "fma " /proc/cpuinfo)
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell sysinfo -cpu | grep "FMA ")
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
endif
else
CFLAGS += -mfma -mf16c -mavx -mavx2
endif
ifeq ($(UNAME_M),x86_64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifneq ($(filter ppc64%,$(UNAME_M)),)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M)))
CFLAGS += -mpower9-vector
endif
# Require c++23's std::byteswap for big-endian support.
ifeq ($(UNAME_M),ppc64)
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
endif
endif
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE
LDFLAGS += -framework Accelerate
CFLAGS += -DGGML_USE_ACCELERATE -DGGML_PERF
LDFLAGS += -framework Foundation -framework Accelerate -framework Metal -framework MetalKit -framework MetalPerformanceShaders
endif
endif
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas
endif
ifdef WHISPER_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
endif
ifneq ($(filter armv6%,$(UNAME_M)),)
@ -156,40 +78,27 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
endif
#
# Print build information
# Build library + main
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
#
# Build library
#
main: examples/main/main.cpp ggml.o ggml-mtl.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp whisper.o ggml.o ggml-mtl.o -o main $(LDFLAGS)
./main -h
ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
ggml-mtl.o: ggml-mtl.m ggml-mtl.h
$(CC) $(CFLAGS) -c ggml-mtl.m -o ggml-mtl.o
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
libwhisper.a: ggml.o whisper.o
$(AR) rcs libwhisper.a ggml.o whisper.o
libwhisper.so: ggml.o whisper.o
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
libwhisper.a: ggml.o ggml-mtl.o whisper.o
$(AR) rcs libwhisper.a ggml.o ggml-mtl.o whisper.o
clean:
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
rm -f *.o main stream bench libwhisper.a
#
# Examples
@ -197,19 +106,9 @@ clean:
CC_SDL=`sdl2-config --cflags --libs`
main: examples/main/main.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
./main -h
stream: examples/stream/stream.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
@ -248,10 +147,9 @@ samples:
.PHONY: small
.PHONY: medium.en
.PHONY: medium
.PHONY: large-v1
.PHONY: large
tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
tiny.en tiny base.en base small.en small medium.en medium large: main
bash ./models/download-ggml-model.sh $@
@echo ""
@echo "==============================================="
@ -260,17 +158,9 @@ tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
@echo ""
@for f in samples/*.wav; do \
echo "----------------------------------------------" ; \
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
echo "----------------------------------------------" ; \
echo "" ; \
./main -m models/ggml-$@.bin -f $$f ; \
echo "" ; \
done
#
# Tests
#
.PHONY: tests
tests:
bash ./tests/run-tests.sh

283
README.md
View File

@ -2,18 +2,14 @@
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.2.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- Low memory usage (Flash Attention)
- Low memory usage (Flash Attention + Flash Forward)
- Zero memory allocations at runtime
- Runs on the CPU
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
@ -22,11 +18,11 @@ Supported platforms:
- [x] Mac OS (Intel and Arm)
- [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android)
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] Linux
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
- [x] [Windows (MSVC and MinGW)](https://github.com/ggerganov/whisper.cpp/issues/5)
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/issues/7)
- [x] [Android](https://github.com/ggerganov/whisper.cpp/issues/30)
The entire implementation of the model is contained in 2 source files:
@ -34,16 +30,10 @@ The entire implementation of the model is contained in 2 source files:
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device:
https://user-images.githubusercontent.com/1991296/197385372-962a6dea-bca1-4d50-bf96-1d8c27b98c81.mp4
You can also easily make your own offline voice assistant application: [command](examples/command)
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
@ -56,6 +46,21 @@ The tensor operators are optimized heavily for Apple silicon CPUs. Depending on
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Limitations
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Quick start
First, download one of the Whisper models converted in [ggml format](models). For example:
@ -71,7 +76,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
make
# transcribe an audio file
./main -f samples/jfk.wav
./main -f input.wav
```
---
@ -89,38 +94,27 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path
bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
@ -139,8 +133,7 @@ Running base.en on all samples in ./samples ...
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
----------------------------------------------
whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
whisper_model_load: loading model
whisper_model_load: loading model from 'models/ggml-base.en.bin'
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512
@ -153,14 +146,13 @@ whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
whisper_model_load: kv self size = 5.25 MB
whisper_model_load: kv cross size = 17.58 MB
whisper_model_load: mem_required = 670.00 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 140.60 MB
whisper_model_load: model size = 140.54 MB
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
@ -168,13 +160,12 @@ main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 proc
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: load time = 113.81 ms
whisper_print_timings: mel time = 15.40 ms
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
whisper_print_timings: total time = 476.31 ms
whisper_print_timings: load time = 105.91 ms
whisper_print_timings: mel time = 24.62 ms
whisper_print_timings: sample time = 3.63 ms
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
whisper_print_timings: total time = 542.81 ms
```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@ -209,7 +200,6 @@ make small.en
make small
make medium.en
make medium
make large-v1
make large
```
@ -217,16 +207,11 @@ make large
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
## Limitations
- Inference only
- No GPU support (yet)
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
## Another example
@ -239,8 +224,7 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
```java
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 1024
@ -253,60 +237,55 @@ whisper_model_load: n_text_layer = 24
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 4
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
whisper_model_load: kv self size = 42.00 MB
whisper_model_load: kv cross size = 140.62 MB
whisper_model_load: mem_required = 2610.00 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 1462.35 MB
whisper_model_load: model size = 1462.12 MB
whisper_model_load: ggml ctx size = 1644.97 MB
whisper_model_load: memory size = 182.62 MB
whisper_model_load: model size = 1462.12 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
[00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
[00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
[00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
[00:29.000 --> 00:32.000] On board was a crew of seven.
[00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
[00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
[00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
[00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
[01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
[01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
[01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
[01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
[01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
[01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
[01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
[01:52.000 --> 01:56.000] The cause in which they died will continue.
[01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
[02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
[02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
[02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
[02:35.000 --> 02:39.000] and calls them each by name."
[02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
[02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[03:01.000 --> 03:05.000] yet we can pray that all are safely home.
[03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
[03:13.000 --> 03:41.000] Audio
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
[00:03:13.000 --> 00:03:19.000] [Silence]
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 569.03 ms
whisper_print_timings: mel time = 146.85 ms
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
whisper_print_timings: total time = 32733.52 ms
whisper_print_timings: load time = 575.92 ms
whisper_print_timings: mel time = 230.60 ms
whisper_print_timings: sample time = 73.19 ms
whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
whisper_print_timings: total time = 33686.27 ms
```
</details>
@ -317,7 +296,6 @@ The [stream](examples/stream) tool samples the audio every half a second and run
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```java
make stream
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
@ -332,14 +310,14 @@ to highlight words with high or low confidence:
## Controlling the length of the generated text segments (experimental)
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
@ -363,7 +341,7 @@ The `--max-len` argument can be used to obtain word-level timestamps. Simply use
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
@ -450,46 +428,21 @@ The original models are converted to a custom binary format. This allows to pack
- vocabulary
- weights
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
or manually from here:
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script or from here:
- https://huggingface.co/datasets/ggerganov/whisper.cpp
- https://ggml.ggerganov.com
https://ggml.ggerganov.com
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
in [models](models).
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README in [models](models).
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
## Bindings
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm)
- [ ] Python:
- [ ] Java:
## Examples
There are various examples of using the library for different projects in the [examples](examples) folder.
Some of the examples are even ported to run in the browser using WebAssembly. Check them out!
There are various examples of using the library for different projects in the [examples](examples) folder. Check them out!
| Example | Web | Description |
| --- | --- | --- |
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
You can use the [Show and tell](https://github.com/ggerganov/whisper.cpp/discussions/categories/show-and-tell) category
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
[Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126) discussion.
## [Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126)

View File

@ -1,19 +1,3 @@
if (EMSCRIPTEN)
add_subdirectory(javascript)
add_custom_command(
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/javascript/publish.log
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/whisper.js
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/libwhisper.worker.js
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/package.json
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/javascript
COMMAND npm publish
COMMAND touch publish.log
COMMENT "Publishing npm module v${PROJECT_VERSION}"
VERBATIM
)
add_custom_target(publish-npm
DEPENDS javascript/publish.log
)
endif()

View File

@ -1,2 +0,0 @@
build
models

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2022 David Thorpe
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,38 +0,0 @@
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
all: clean whisper examples
whisper: mkdir
@echo Build whisper
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir:
@echo Mkdir ${BUILD_DIR}
@install -d ${BUILD_DIR}
@echo Mkdir ${MODELS_DIR}
@install -d ${MODELS_DIR}
modtidy:
@go mod tidy
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go clean

View File

@ -1,100 +0,0 @@
# Go bindings for Whisper
This package provides Go bindings for whisper.cpp. They have been tested on:
* Darwin (OS X) 12.6 on x64_64
* Debian Linux on arm64
* Fedora Linux on x86_64
The "low level" bindings are in the `bindings/go` directory and there is a more
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
is as follows:
```go
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
var modelpath string // Path to the model
var samples []float32 // Samples to process
// Load the model
model, err := whisper.New(modelpath)
if err != nil {
panic(err)
}
defer model.Close()
// Process samples
context, err := model.NewContext()
if err != nil {
panic(err)
}
if err := context.Process(samples, nil); err != nil {
return err
}
// Print out the results
for {
segment, err := context.NextSegment()
if err != nil {
break
}
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
}
}
```
## Building & Testing
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/go
make test
```
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
```bash
make examples
```
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
```bash
./build/go-model-download -out models
```
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -1,5 +0,0 @@
/*
github.com/ggerganov/whisper.cpp/bindings/go
provides a speech-to-text service bindings for the Go programming language.
*/
package whisper

View File

@ -1,30 +0,0 @@
package main
import (
"context"
"os"
"os/signal"
)
// ContextForSignal returns a context object which is cancelled when a signal
// is received. It returns nil if no signal parameter is provided
func ContextForSignal(signals ...os.Signal) context.Context {
if len(signals) == 0 {
return nil
}
ch := make(chan os.Signal)
ctx, cancel := context.WithCancel(context.Background())
// Send message on channel when signal received
signal.Notify(ch, signals...)
// When any signal received, call cancel
go func() {
<-ch
cancel()
}()
// Return success
return ctx
}

View File

@ -1,208 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"syscall"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
)
var (
// The output folder. When not set, use current working directory.
flagOut = flag.String("out", "", "Output folder")
// HTTP timeout parameter - will timeout if takes longer than this to download a model
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
// Quiet parameter - will not print progress if set
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
)
///////////////////////////////////////////////////////////////////////////////
// MAIN
func main() {
flag.Usage = func() {
name := filepath.Base(flag.CommandLine.Name())
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
flag.PrintDefaults()
}
flag.Parse()
// Get output path
out, err := GetOut()
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
// Create context which quits on SIGINT or SIGQUIT
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
// Progress filehandle
progress := os.Stdout
if *flagQuiet {
progress, err = os.Open(os.DevNull)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
os.Exit(-1)
}
defer progress.Close()
}
// Download models - exit on error or interrupt
for _, model := range GetModels() {
url, err := URLForModel(model)
if err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
continue
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
continue
} else if err == context.Canceled {
os.Remove(path)
fmt.Fprintln(progress, "\nInterrupted")
break
} else if err == context.DeadlineExceeded {
os.Remove(path)
fmt.Fprintln(progress, "Timeout downloading model")
continue
} else {
os.Remove(path)
fmt.Fprintln(os.Stderr, "Error:", err)
break
}
}
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// GetOut returns the path to the output directory
func GetOut() (string, error) {
if *flagOut == "" {
return os.Getwd()
}
if info, err := os.Stat(*flagOut); err != nil {
return "", err
} else if !info.IsDir() {
return "", fmt.Errorf("not a directory: %s", info.Name())
} else {
return *flagOut, nil
}
}
// GetModels returns the list of models to download
func GetModels() []string {
if flag.NArg() == 0 {
return modelNames
} else {
return flag.Args()
}
}
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
} else {
url.Path = filepath.Join(url.Path, model)
}
return url.String(), nil
}
// Download downloads the model from the given URL to the given output directory
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
// Create HTTP client
client := http.Client{
Timeout: *flagTimeout,
}
// Initiate the download
req, err := http.NewRequest("GET", model, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s: %s", model, resp.Status)
}
// If output file exists and is the same size as the model, skip
path := filepath.Join(out, filepath.Base(model))
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
fmt.Fprintln(p, "Skipping", model, "as it already exists")
return "", nil
}
// Create file
w, err := os.Create(path)
if err != nil {
return "", err
}
defer w.Close()
// Report
fmt.Fprintln(p, "Downloading", model, "to", out)
// Progressively download the model
data := make([]byte, bufSize)
count, pct := int64(0), int64(0)
ticker := time.NewTicker(5 * time.Second)
for {
select {
case <-ctx.Done():
// Cancelled, return error
return path, ctx.Err()
case <-ticker.C:
pct = DownloadReport(p, pct, count, resp.ContentLength)
default:
// Read body
n, err := resp.Body.Read(data)
if err != nil {
DownloadReport(p, pct, count, resp.ContentLength)
return path, err
} else if m, err := w.Write(data[:n]); err != nil {
return path, err
} else {
count += int64(m)
}
}
}
}
// Report periodically reports the download progress when percentage changes
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
pct_ := count * 100 / total
if pct_ > pct {
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
}
return pct_
}

View File

@ -1,22 +0,0 @@
package main
import "fmt"
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
Reset = "\033[0m"
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
RGBSuffix = "m"
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Colorize text with RGB values, from 0 to 23
func Colorize(text string, v int) string {
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
// Grayscale colors are in the range 232-255
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
}

View File

@ -1,156 +0,0 @@
package main
import (
"flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type Flags struct {
*flag.FlagSet
}
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewFlags(name string, args []string) (*Flags, error) {
flags := &Flags{
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
}
// Register the command line arguments
registerFlags(flags)
// Parse command line
if err := flags.Parse(args); err != nil {
return nil, err
}
// Return success
return flags, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (flags *Flags) GetModel() string {
return flags.Lookup("model").Value.String()
}
func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}
func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}
func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}
func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}
func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}
func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
}

View File

@ -1,43 +0,0 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
if err == flag.ErrHelp {
os.Exit(0)
} else if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
} else if flags.GetModel() == "" {
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
os.Exit(1)
} else if flags.NArg() == 0 {
fmt.Fprintln(os.Stderr, "No input files specified")
os.Exit(1)
}
// Load model
model, err := whisper.New(flags.GetModel())
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
defer model.Close()
// Process files
for _, filename := range flags.Args() {
if err := Process(model, filename, flags); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
}
}

View File

@ -1,132 +0,0 @@
package main
import (
"fmt"
"io"
"os"
"time"
// Package imports
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32
// Create processing context
context, err := model.NewContext()
if err != nil {
return err
}
// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}
fmt.Printf("\n%s\n", context.SystemInfo())
// Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
} else if dec.SampleRate != whisper.SampleRate {
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
} else if dec.NumChans != 1 {
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
} else {
data = buf.AsFloat32Buffer().Data
}
// Segment callback when -tokens is specified
var cb whisper.SegmentCallback
if flags.IsTokens() {
cb = func(segment whisper.Segment) {
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
}
fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
}
}
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb); err != nil {
return err
}
context.PrintTimings()
// Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}
// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintln(w, n)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
}
}
// Output text to terminal
func Output(w io.Writer, context whisper.Context, colorize bool) error {
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
if colorize {
for _, token := range segment.Tokens {
if !context.IsText(token) {
continue
}
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
}
fmt.Fprint(w, "\n")
} else {
fmt.Fprintln(w, " ", segment.Text)
}
}
}
// Return srtTimestamp
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
}

View File

@ -1,16 +0,0 @@
module github.com/ggerganov/whisper.cpp/bindings/go
go 1.19
require (
github.com/go-audio/wav v1.1.0
github.com/stretchr/testify v1.8.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,23 +0,0 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,165 +0,0 @@
package whisper
import (
"fmt"
)
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#include <whisper.h>
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
func (p *Params) SetTranslate(v bool) {
p.translate = toBool(v)
}
func (p *Params) SetNoContext(v bool) {
p.no_context = toBool(v)
}
func (p *Params) SetSingleSegment(v bool) {
p.single_segment = toBool(v)
}
func (p *Params) SetPrintSpecial(v bool) {
p.print_special = toBool(v)
}
func (p *Params) SetPrintProgress(v bool) {
p.print_progress = toBool(v)
}
func (p *Params) SetPrintRealtime(v bool) {
p.print_realtime = toBool(v)
}
func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}
func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
// Set language id
func (p *Params) SetLanguage(lang int) error {
if lang == -1 {
p.language = nil
return nil
}
str := C.whisper_lang_str(C.int(lang))
if str == nil {
return ErrInvalidLanguage
} else {
p.language = str
}
return nil
}
// Get language id
func (p *Params) Language() int {
if p.language == nil {
return -1
}
return int(C.whisper_lang_id(p.language))
}
// Threads available
func (p *Params) Threads() int {
return int(p.n_threads)
}
// Set number of threads to use
func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads)
}
// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms)
}
// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms)
}
// Set timestamp token probability threshold (~0.01)
func (p *Params) SetTokenThreshold(t float32) {
p.thold_pt = C.float(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (p *Params) SetTokenSumThreshold(t float32) {
p.thold_ptsum = C.float(t)
}
// Set max segment length in characters
func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}
// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func toBool(v bool) C.bool {
if v {
return C.bool(true)
}
return C.bool(false)
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (p *Params) String() string {
str := "<whisper.params"
str += fmt.Sprintf(" strategy=%v", p.strategy)
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
if p.language != nil {
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
}
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
if p.translate {
str += " translate"
}
if p.no_context {
str += " no_context"
}
if p.single_segment {
str += " single_segment"
}
if p.print_special {
str += " print_special"
}
if p.print_progress {
str += " print_progress"
}
if p.print_realtime {
str += " print_realtime"
}
if p.print_timestamps {
str += " print_timestamps"
}
if p.token_timestamps {
str += " token_timestamps"
}
if p.speed_up {
str += " speed_up"
}
return str + ">"
}

View File

@ -1,28 +0,0 @@
package whisper
import (
"errors"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
)
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
// SampleRate is the sample rate of the audio data.
const SampleRate = whisper.SampleRate
// SampleBits is the number of bytes per sample.
const SampleBits = whisper.SampleBits

View File

@ -1,290 +0,0 @@
package whisper
import (
"fmt"
"io"
"runtime"
"strings"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type context struct {
n int
model *model
params whisper.Params
}
// Make sure context adheres to the interface
var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func newContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
// Return success
return context, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Set the language to use for speech recognition.
func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
if !context.model.IsMultilingual() {
return ErrModelNotMultilingual
}
if lang == "auto" {
context.params.SetLanguage(-1)
} else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
return err
}
// Return success
return nil
}
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Get language
func (context *context) Language() string {
id := context.params.Language()
if id == -1 {
return "auto"
}
return whisper.Whisper_lang_str(context.params.Language())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
}
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
}
// Set time offset
func (context *context) SetOffset(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
context.params.SetTokenThreshold(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
context.params.SetTokenSumThreshold(t)
}
// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// ResetTimings resets the mode timings. Should be called before processing
func (context *context) ResetTimings() {
context.model.ctx.Whisper_reset_timings()
}
// PrintTimings prints the model timings to stdout.
func (context *context) PrintTimings() {
context.model.ctx.Whisper_print_timings()
}
// SystemInfo returns the system information
func (context *context) SystemInfo() string {
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
context.params.Threads(),
runtime.NumCPU(),
whisper.Whisper_print_system_info(),
)
}
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
if err != nil {
return nil, err
}
return langProbs, nil
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
// If the callback is defined then we force on single_segment mode
if cb != nil {
context.params.SetSingleSegment(true)
}
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
// Return success
return nil
}
// Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) {
if context.model.ctx == nil {
return Segment{}, ErrInternalAppError
}
if context.n >= context.model.ctx.Whisper_full_n_segments() {
return Segment{}, io.EOF
}
// Populate result
result := toSegment(context.model.ctx, context.n)
// Increment the cursor
context.n++
// Return success
return result, nil
}
// Test for text tokens
func (context *context) IsText(t Token) bool {
switch {
case context.IsBEG(t):
return false
case context.IsSOT(t):
return false
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
return false
case context.IsPREV(t):
return false
case context.IsSOLM(t):
return false
case context.IsNOT(t):
return false
default:
return true
}
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}
// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}
// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}
// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}
// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}
// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}
// Test for token associated with a specific language
func (context *context) IsLANG(t Token, lang string) bool {
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
} else {
return false
}
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func toSegment(ctx *whisper.Context, n int) Segment {
return Segment{
Num: n,
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
Tokens: toTokens(ctx, n),
}
}
func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ {
result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i),
}
}
return result
}

View File

@ -1,55 +0,0 @@
package whisper_test
import (
"os"
"testing"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "../../models/ggml-tiny.bin"
SamplePath = "../../samples/jfk.wav"
)
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Load model
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
assert.NoError(model.Close())
t.Log("languages=", model.Languages())
}
func Test_Whisper_001(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Load model
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
// Get context for decoding
ctx, err := model.NewContext()
assert.NoError(err)
assert.NotNil(ctx)
}

View File

@ -1,4 +0,0 @@
/*
This is the higher-level speech-to-text whisper.cpp API for go
*/
package whisper

View File

@ -1,91 +0,0 @@
package whisper
import (
"io"
"time"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
// SegmentCallback is the callback function for processing segments in real
// time. It is called during the Process function
type SegmentCallback func(Segment)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
io.Closer
// Return a new speech-to-text context.
NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported.
Languages() []string
}
// Context is the speach recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
// Timings
PrintTimings()
ResetTimings()
SystemInfo() string
}
// Segment is the text result of a speech recognition.
type Segment struct {
// Segment Number
Num int
// Time beginning and end timestamps for the segment.
Start, End time.Duration
// The text of the segment.
Text string
// The tokens of the segment.
Tokens []Token
}
// Token is a text or special token
type Token struct {
Id int
Text string
P float32
}

View File

@ -1,100 +0,0 @@
package whisper
import (
"fmt"
"os"
"runtime"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type model struct {
path string
ctx *whisper.Context
}
// Make sure model adheres to the interface
var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (Model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
model.path = path
}
// Return success
return model, nil
}
func (model *model) Close() error {
if model.ctx != nil {
model.ctx.Whisper_free()
}
// Release resources
model.ctx = nil
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
func (model *model) String() string {
str := "<whisper.model"
if model.ctx != nil {
str += fmt.Sprintf(" model=%q", model.path)
}
return str + ">"
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
str := whisper.Whisper_lang_str(i)
if model.ctx.Whisper_lang_id(str) >= 0 {
result = append(result, str)
}
}
return result
}
func (model *model) NewContext() (Context, error) {
if model.ctx == nil {
return nil, ErrInternalAppError
}
// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
params.SetPrintRealtime(false)
params.SetPrintTimestamps(false)
params.SetThreads(runtime.NumCPU())
// Return new context
return newContext(model, params)
}

Binary file not shown.

View File

@ -1,409 +0,0 @@
package whisper
import (
"errors"
"unsafe"
)
///////////////////////////////////////////////////////////////////////////////
// CGO
/*
#cgo LDFLAGS: -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h>
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callNewSegment(user_data, n_new);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) {
return callEncoderBegin(user_data);
}
return false;
}
// Get default parameters and set callbacks
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
struct whisper_full_params params = whisper_full_default_params(strategy);
params.new_segment_callback = whisper_new_segment_cb;
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
return params;
}
*/
import "C"
///////////////////////////////////////////////////////////////////////////////
// TYPES
type (
Context C.struct_whisper_context
Token C.whisper_token
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
)
///////////////////////////////////////////////////////////////////////////////
// GLOBALS
const (
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
)
const (
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
NumFFT = C.WHISPER_N_FFT
NumMEL = C.WHISPER_N_MEL
HopLength = C.WHISPER_HOP_LENGTH
ChunkSize = C.WHISPER_CHUNK_SIZE
)
var (
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
ErrConversionFailed = errors.New("whisper_convert failed")
ErrInvalidLanguage = errors.New("invalid language")
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil
}
}
// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))
}
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
func (ctx *Context) Whisper_encode(offset, threads int) error {
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
cText := C.CString(text)
defer C.free(unsafe.Pointer(cText))
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
return int(n), nil
} else {
return 0, ErrTokenizerFailed
}
}
// Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
// Largest language id (i.e. number of available languages - 1)
func Whisper_lang_max_id() int {
return int(C.whisper_lang_max_id())
}
// Return the short string of the specified language id (e.g. 2 -> "de"),
// returns empty string if not found
func Whisper_lang_str(id int) string {
return C.GoString(C.whisper_lang_str(C.int(id)))
}
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
probs := make([]float32, Whisper_lang_max_id()+1)
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
return nil, ErrAutoDetectFailed
} else {
return probs, nil
}
}
func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_vocab() int {
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
}
// The probabilities for the next token
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
//}
// Token Id -> String. Uses the vocabulary in the provided context
func (ctx *Context) Whisper_token_to_str(token Token) string {
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
}
// Special tokens
func (ctx *Context) Whisper_token_eot() Token {
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_sot() Token {
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_prev() Token {
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_solm() Token {
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_not() Token {
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_beg() Token {
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
}
// Special tokens
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
}
// Task tokens
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
}
// Task tokens
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
}
// Performance information
func (ctx *Context) Whisper_print_timings() {
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
}
// Performance information
func (ctx *Context) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
}
// Print system information
func Whisper_print_system_info() string {
return C.GoString(C.whisper_print_system_info())
}
// Return default parameters for a strategy
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
// Get default parameters
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
}
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
}
// Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
// Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
///////////////////////////////////////////////////////////////////////////////
// CALLBACKS
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbNewSegment, unsafe.Pointer(ctx))
} else {
cbNewSegment[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
} else {
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
}
}
//export callNewSegment
func callNewSegment(user_data unsafe.Pointer, new C.int) {
if fn, ok := cbNewSegment[user_data]; ok {
fn(int(new))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
if fn() {
return C.bool(true)
} else {
return C.bool(false)
}
}
return true
}

View File

@ -1,113 +0,0 @@
package whisper_test
import (
"os"
"runtime"
"testing"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
wav "github.com/go-audio/wav"
assert "github.com/stretchr/testify/assert"
)
const (
ModelPath = "models/ggml-small.en.bin"
SamplePath = "samples/jfk.wav"
)
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
ctx.Whisper_free()
}
func Test_Whisper_001(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
assert.NoError(err)
// Run whisper
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens
num_segments := ctx.Whisper_full_n_segments()
assert.GreaterOrEqual(num_segments, 1)
for i := 0; i < num_segments; i++ {
str := ctx.Whisper_full_get_segment_text(i)
assert.NotEmpty(str)
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
t.Logf("[%6s->%-6s] %q", t0, t1, str)
}
}
func Test_Whisper_002(t *testing.T) {
assert := assert.New(t)
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
str := whisper.Whisper_lang_str(i)
assert.NotEmpty(str)
t.Log(str)
}
}
func Test_Whisper_003(t *testing.T) {
threads := runtime.NumCPU()
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
assert.NoError(err)
// Make the model
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
// Get MEL
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
// Get Languages
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
assert.NoError(err)
for i, p := range languages {
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
}
}

View File

@ -9,33 +9,25 @@ target_link_libraries(${TARGET} PRIVATE
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside whisper.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
TARGET libwhisper POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libwhisper.js
${CMAKE_CURRENT_SOURCE_DIR}/whisper.js
)
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libwhisper.worker.js
${CMAKE_CURRENT_SOURCE_DIR}/libwhisper.worker.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s MODULARIZE=1 \
-s EXPORT_NAME=\"'whisper_factory'\" \
-s FORCE_FILESYSTEM=1 \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s ALLOW_MEMORY_GROWTH=1 \
-s INITIAL_MEMORY=1610612736 \
-s TOTAL_MEMORY=1610612736 \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")

View File

@ -1,78 +0,0 @@
# whisper.cpp
Node.js package for Whisper speech recognition
Package: https://www.npmjs.com/package/whisper.cpp
## Details
The performance is comparable to when running `whisper.cpp` in the browser via WASM.
The API is currently very rudimentary: [bindings/javascript/emscripten.cpp](/bindings/javascript/emscripten.cpp)
For sample usage check [tests/test-whisper.js](/tests/test-whisper.js)
## Package building + test
```bash
# load emscripten
source /path/to/emsdk/emsdk_env.sh
# clone repo
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
# grab base.en model
./models/download-ggml-model.sh base.en
# prepare PCM sample for testing
ffmpeg -i samples/jfk.wav -f f32le -acodec pcm_f32le samples/jfk.pcmf32
# build
mkdir build-em && cd build-em
emcmake cmake .. && make -j
# run test
node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
# publish npm package
make publish-npm
```
## Sample run
```java
$ node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
whisper_model_load: loading model from 'whisper.bin'
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512
whisper_model_load: n_audio_head = 8
whisper_model_load: n_audio_layer = 6
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 512
whisper_model_load: n_text_head = 8
whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: adding 1607 extra tokens
whisper_model_load: mem_required = 506.00 MB
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 1 | BLAS = 0 |
operator(): processing 176000 samples, 11.0 sec, 8 threads, 1 processors, lang = en, task = transcribe ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: load time = 162.37 ms
whisper_print_timings: mel time = 183.70 ms
whisper_print_timings: sample time = 4.27 ms
whisper_print_timings: encode time = 8582.63 ms / 1430.44 ms per layer
whisper_print_timings: decode time = 436.16 ms / 72.69 ms per layer
whisper_print_timings: total time = 9370.90 ms
```

View File

@ -1,58 +1,59 @@
//
// This is the Javascript API of whisper.cpp
//
// Very crude at the moment.
// Feel free to contribute and make this better!
//
// See the tests/test-whisper.js for sample usage
//
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <thread>
#include <vector>
#include <thread>
struct whisper_context * g_context;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str());
if (g_context != nullptr) {
return true;
} else {
return false;
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
return i + 1;
} else {
return (size_t) 0;
}
}
}
return false;
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([]() {
if (g_context) {
whisper_free(g_context);
g_context = nullptr;
emscripten::function("free", emscripten::optional_override([](size_t index) {
--index;
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}));
emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) {
if (g_context == nullptr) {
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, bool translate) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
struct whisper_full_params params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
params.print_realtime = true;
params.print_progress = false;
params.print_timestamps = true;
params.print_special = false;
params.translate = translate;
params.language = whisper_is_multilingual(g_context) ? lang.c_str() : "en";
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
params.offset_ms = 0;
params.print_realtime = true;
params.print_progress = false;
params.print_timestamps = true;
params.print_special_tokens = false;
params.translate = translate;
params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
params.offset_ms = 0;
std::vector<float> pcmf32;
const int n = audio["length"].as<int>();
@ -67,11 +68,9 @@ EMSCRIPTEN_BINDINGS(whisper) {
// print system information
{
printf("\n");
printf("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
printf("\n");
printf("%s: processing %d samples, %.1f sec, %d threads, %d processors, lang = %s, task = %s ...\n",
__func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, 1,
@ -81,13 +80,10 @@ EMSCRIPTEN_BINDINGS(whisper) {
printf("\n");
}
// run whisper
{
whisper_reset_timings(g_context);
whisper_full(g_context, params, pcmf32.data(), pcmf32.size());
whisper_print_timings(g_context);
}
int ret = whisper_full(g_contexts[index], params, pcmf32.data(), pcmf32.size());
return 0;
whisper_print_timings(g_contexts[index]);
return ret;
}));
}

View File

@ -1 +0,0 @@
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f)},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}var initializedJS=false;var pendingNotifiedProxyingQueues=[];function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};self.onunhandledrejection=e=>{throw e.reason??e};self.onmessage=e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=function(){postMessage({cmd:"callHandler",handler:handler,args:[...arguments]})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();pendingNotifiedProxyingQueues.forEach(queue=>{Module["executeNotifiedProxyingQueue"](queue)});pendingNotifiedProxyingQueues=[];initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processProxyingQueue"){if(initializedJS){Module["executeNotifiedProxyingQueue"](e.data.queue)}else{pendingNotifiedProxyingQueues.push(e.data.queue)}}else if(e.data.cmd){err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}};

View File

@ -1,26 +0,0 @@
{
"name": "whisper.cpp",
"version": "@PROJECT_VERSION@",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {
"test": "echo \"todo: add tests\" && exit 0"
},
"repository": {
"type": "git",
"url": "git+https://github.com/ggerganov/whisper.cpp"
},
"keywords": [
"openai",
"whisper",
"speech-to-text",
"speech-recognition",
"transformer"
],
"author": "Georgi Gerganov",
"license": "MIT",
"bugs": {
"url": "https://github.com/ggerganov/whisper.cpp/issues"
},
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
}

View File

@ -1,26 +0,0 @@
{
"name": "whisper.cpp",
"version": "1.2.0",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {
"test": "echo \"todo: add tests\" && exit 0"
},
"repository": {
"type": "git",
"url": "git+https://github.com/ggerganov/whisper.cpp"
},
"keywords": [
"openai",
"whisper",
"speech-to-text",
"speech-recognition",
"transformer"
],
"author": "Georgi Gerganov",
"license": "MIT",
"bugs": {
"url": "https://github.com/ggerganov/whisper.cpp/issues"
},
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
}

File diff suppressed because one or more lines are too long

View File

@ -1,17 +0,0 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -20,16 +20,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN)
add_subdirectory(whisper.wasm)
add_subdirectory(stream.wasm)
add_subdirectory(command.wasm)
add_subdirectory(talk.wasm)
add_subdirectory(bench.wasm)
elseif(CMAKE_JS_VERSION)
add_subdirectory(addon.node)
else()
add_subdirectory(main)
add_subdirectory(stream)
add_subdirectory(command)
add_subdirectory(bench)
add_subdirectory(talk)
endif()

View File

@ -1,3 +0,0 @@
.idea
node_modules
build

View File

@ -1,26 +0,0 @@
set(TARGET whisper-addon)
# Base settings
#==================================================================
# env var supported by cmake-js
add_definitions(-DNAPI_VERSION=4)
include_directories(${CMAKE_JS_INC})
#==================================================================
add_library(${TARGET} SHARED ${CMAKE_JS_SRC} addon.cpp)
set_target_properties(${TARGET} PROPERTIES PREFIX "" SUFFIX ".node")
include(DefaultTargetOptions)
# Include N-API wrappers
#==================================================================
execute_process(COMMAND node -p "require('node-addon-api').include"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE NODE_ADDON_API_DIR
)
string(REPLACE "\n" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
#==================================================================
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,37 +0,0 @@
# addon
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
It can be used as a reference for using the whisper.cpp project in other node projects.
## Install
```shell
npm install
```
## Compile
Make sure it is in the project root directory and compiled with make-js.
```shell
npx cmake-js compile -T whisper-addon
```
For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes.
> Such as appointing special cmake path:
> ```shell
> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon
> ```
## Run
```shell
cd examples/addon.node
node index.js --language='language' --model='model-path' --fname_inp='file-path'
```
Because this is a simple Demo, only the above parameters are set in the node environment.
Other parameters can also be specified in the node environment.

View File

@ -1,421 +0,0 @@
#include <string>
#include <thread>
#include <vector>
#include <cmath>
#include "napi.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "whisper.h"
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string prompt;
std::string model = "../../ggml-large.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {};
};
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
};
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t, bool comma = false) {
int64_t msec = t * 10;
int64_t hr = msec / (1000 * 60 * 60);
msec = msec - hr * (1000 * 60 * 60);
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
return std::string(buf);
}
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
// colorful print bug
//
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < (int)params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "error: WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "error: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "error: WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "error: WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, true));
result[i].emplace_back(to_timestamp(t1, true));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
Napi::Object whisper(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() <= 0 || !info[0].IsObject()) {
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException();
}
whisper_params params;
std::vector<std::vector<std::string>> result;
Napi::Object whisper_params = info[0].As<Napi::Object>();
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
// run model
run(params, result);
fprintf(stderr, "RESULT:\n");
for (auto sentence:result) {
fprintf(stderr, "t0: %s, t1: %s, content: %s \n",
sentence[0].c_str(), sentence[1].c_str(), sentence[2].c_str());
}
Napi::Object res = Napi::Array::New(env, result.size());
for (u_int32_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(env, 3);
for (u_int32_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(env, result[i][j]);
}
res[i] = tmp;
}
return res;
}
Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports.Set(
Napi::String::New(env, "whisper"),
Napi::Function::New(env, whisper)
);
return exports;
}
NODE_API_MODULE(whisper, Init);

View File

@ -1,27 +0,0 @@
const path = require('path');
const { whisper } = require(path.join(__dirname, '../../build/Release/whisper-addon'));
const whisperParams = {
language: 'en',
model: path.join(__dirname, '../../models/ggml-base.en.bin'),
fname_inp: '',
};
const arguments = process.argv.slice(2);
const params = Object.fromEntries(
arguments.reduce((pre, item) => {
if (item.startsWith("--")) {
return [...pre, item.slice(2).split("=")];
}
return pre;
}, []),
);
for (const key in params) {
if (whisperParams.hasOwnProperty(key)) {
whisperParams[key] = params[key];
}
}
console.log('whisperParams =', whisperParams);
console.log(whisper(whisperParams));

View File

@ -1,12 +0,0 @@
{
"name": "whisper-addon",
"version": "0.0.0",
"description": "",
"main": "index.js",
"author": "Qanhe Chen",
"license": "MIT",
"devDependencies": {
"cmake-js": "^7.1.1",
"node-addon-api": "^5.0.0"
}
}

View File

@ -1,49 +0,0 @@
#
# libbench
#
set(TARGET libbench)
add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside bench.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libbench.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/bench.wasm/bench.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# bench.wasm
#
set(TARGET bench.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -1,22 +0,0 @@
# bench.wasm
Benchmark the performance of whisper.cpp in the browser using WebAssembly
Link: https://whisper.ggerganov.com/bench/
Terminal version: [examples/bench](/examples/bench)
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/bench.wasm/* /path/to/html/
cp bin/libbench.worker.js /path/to/html/
```

View File

@ -1,85 +0,0 @@
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <cmath>
#include <string>
#include <thread>
#include <vector>
constexpr int N_THREAD = 8;
// TODO: get rid of this vector of contexts - bad idea in the first place
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::thread g_worker;
void bench_main(size_t index) {
const int n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// whisper context
auto & ctx = g_contexts[index];
fprintf(stderr, "%s: running benchmark with %d threads - please wait...\n", __func__, n_threads);
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return;
}
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
}
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return;
}
whisper_print_timings(ctx);
fprintf(stderr, "\n");
fprintf(stderr, "If you wish, you can submit these results here:\n");
fprintf(stderr, "\n");
fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n");
fprintf(stderr, "\n");
fprintf(stderr, "Please include the following information:\n");
fprintf(stderr, "\n");
fprintf(stderr, " - CPU model\n");
fprintf(stderr, " - Operating system\n");
fprintf(stderr, " - Browser\n");
fprintf(stderr, "\n");
}
EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
bench_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}));
}

View File

@ -1,227 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>bench : Benchmark whisper.cpp performance in the browser</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>bench : Benchmark whisper.cpp performance in the browser</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">GitHub</a>.
<br><br>
<hr>
Select the model you would like to use and click the "Bench" button.<br>
The results will be displayed in the textarea below.
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
</div>
<br>
<div id="input">
<button id="bench" onclick="onBench()" disabled>Bench</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// the bench instance
var instance = null;
// model name
var model_whisper = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
model_whisper = fname;
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
if (model_whisper != null) {
document.getElementById('bench').disabled = false;
}
}
function loadFile(event, fname) {
var file = event.target.files[0] || null;
if (file == null) {
return;
}
printTextarea("loadFile: loading model: " + file.name + ", size: " + file.size + " bytes");
printTextarea('loadFile: please wait ...');
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
storeFS(fname, buf);
}
reader.readAsArrayBuffer(file);
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// main
//
function onBench() {
if (instance) {
Module.free(instance);
}
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
document.getElementById('bench').disabled = true;
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
}
</script>
<script type="text/javascript" src="bench.js"></script>
</body>
</html>

View File

@ -1,6 +1,3 @@
set(TARGET bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,8 +1,6 @@
# bench
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
of the performance of the model for various setups.
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of the transformer on some random audio data and records the execution time. This way we can have an objective comparison of the performance of the model for various setups.
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89

View File

@ -6,10 +6,9 @@
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
std::string model = "models/ggml-base.en.bin";
std::string model = "models/ggml-base.en.bin";
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -18,14 +17,14 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else {
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -35,25 +34,27 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
}
int whisper_bench_encoder(const whisper_params & params) {
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
{
fprintf(stderr, "\n");
@ -92,22 +93,3 @@ int whisper_bench_encoder(const whisper_params & params) {
return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
int ret = -1;
switch (params.what) {
case 0: ret = whisper_bench_encoder(params); break;
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
}
return ret;
}

View File

@ -1,49 +0,0 @@
#
# libcommand
#
set(TARGET libcommand)
add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside command.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libcommand.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/command.wasm/command.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# command.wasm
#
set(TARGET command.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -1,23 +0,0 @@
# command.wasm
This is a basic Voice Assistant example that accepts voice commands from the microphone.
It runs in fully in the browser via WebAseembly.
Online demo: https://whisper.ggerganov.com/command/
Terminal version: [examples/command](/examples/command)
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/command.wasm/* /path/to/html/
cp bin/libcommand.worker.js /path/to/html/
```

View File

@ -1,408 +0,0 @@
#include "ggml.h"
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
constexpr int N_THREAD = 8;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
std::string g_status = "";
std::string g_status_forced = "";
std::string g_transcribed = "";
std::vector<float> g_pcmf32;
static std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
// compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
void command_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
const int64_t n_samples = (ms * sample_rate) / 1000;
int64_t n_take = 0;
if (g_pcmf32.size() < n_samples) {
n_take = g_pcmf32.size();
} else {
n_take = n_samples;
}
audio.resize(n_take);
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
}
void command_main(size_t index) {
command_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.language = "en";
printf("command: using %d threads\n", wparams.n_threads);
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
bool print_energy = false;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands.";
// whisper context
auto & ctx = g_contexts[index];
const int32_t vad_ms = 2000;
const int32_t prompt_ms = 5000;
const int32_t command_ms = 4000;
const float vad_thold = 0.1f;
const float freq_thold = -1.0f;
while (g_running) {
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
command_set_status(txt);
}
ask_prompt = false;
}
int64_t t_ms = 0;
{
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
command_set_status("Speech detected! Processing ...");
if (!have_prompt) {
command_get_audio(prompt_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
command_set_status(txt);
}
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
command_get_audio(command_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
command_set_status(txt);
}
{
std::lock_guard<std::mutex> lock(g_mutex);
g_transcribed = command;
}
}
g_pcmf32.clear();
}
}
}
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
command_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("get_transcribed", emscripten::optional_override([]() {
std::string transcribed;
{
std::lock_guard<std::mutex> lock(g_mutex);
transcribed = std::move(g_transcribed);
}
return transcribed;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}
}));
}

View File

@ -1,386 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>command : Voice assistant example using Whisper + WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>command : Voice assistant example using Whisper + WebAssembly</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
<br><br>
<hr>
Select the model you would like to use, click the "Start" button and follow the instructions.
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-transcribed">[The recognized voice commands will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the command instance
var instance = null;
// model name
var model_whisper = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
if (model_whisper != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = true;
}
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
// main
//
var nLines = 0;
var intervalUpdate = null;
var transcribedAll = '';
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var transcribed = Module.get_transcribed();
if (transcribed != null && transcribed.length > 1) {
transcribedAll += transcribed + '<br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = transcribedAll.indexOf('<br>');
if (i > 0) {
transcribedAll = transcribedAll.substring(i + 4);
nLines--;
}
}
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-transcribed').innerHTML = transcribedAll;
}, 100);
}
function onStop() {
stopRecording();
}
</script>
<script type="text/javascript" src="command.js"></script>
</body>
</html>

View File

@ -1,10 +0,0 @@
if (WHISPER_SUPPORT_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -1,47 +0,0 @@
# command
This is a basic Voice Assistant example that accepts voice commands from the microphone.
More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/issues/171).
```bash
# Run with default arguments and small model
./command -m ./models/ggml-small.en.bin -t 8
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
## Building
The `command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev
# Install SDL2 on Mac OS
brew install sdl2
make command
```

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +0,0 @@
enable
disable
cat
dog
apple
red
blue
green
lightblue

View File

@ -1,16 +1,5 @@
#!/bin/bash
# Simple tool to record audio from the microphone and generate a karaoke video
# Usage:
#
# cd whisper.cpp
# make
#
# ./examples/generate-karaoke.sh [model] [step_ms]
#
# Press Ctrl+C to stop recording
#
executable="./main"
model="base.en"
model_path="models/ggml-$model.bin"

View File

@ -1,186 +0,0 @@
// Common Javascript functions used by the examples
function convertTypedArray(src, type) {
var buffer = new ArrayBuffer(src.byteLength);
var baseView = new src.constructor(buffer).set(src);
return new type(buffer);
}
var printTextarea = (function() {
var element = document.getElementById('output');
if (element) element.value = ''; // clear browser cache
return function(text) {
if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
console.log(text);
if (element) {
element.value += text + "\n";
element.scrollTop = element.scrollHeight; // focus on bottom
}
};
})();
async function clearCache() {
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
indexedDB.deleteDatabase(dbName);
}
}
// fetch a remote file from remote URL using the Fetch API
async function fetchRemote(url, cbProgress, cbPrint) {
cbPrint('fetchRemote: downloading with fetch()...');
const response = await fetch(
url,
{
method: 'GET',
headers: {
'Content-Type': 'application/octet-stream',
},
}
);
if (!response.ok) {
cbPrint('fetchRemote: failed to fetch ' + url);
return;
}
const contentLength = response.headers.get('content-length');
const total = parseInt(contentLength, 10);
const reader = response.body.getReader();
var chunks = [];
var receivedLength = 0;
var progressLast = -1;
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
chunks.push(value);
receivedLength += value.length;
if (contentLength) {
cbProgress(receivedLength/total);
var progressCur = Math.round((receivedLength / total) * 10);
if (progressCur != progressLast) {
cbPrint('fetchRemote: fetching ' + 10*progressCur + '% ...');
progressLast = progressCur;
}
}
}
var position = 0;
var chunksAll = new Uint8Array(receivedLength);
for (var chunk of chunks) {
chunksAll.set(chunk, position);
position += chunk.length;
}
return chunksAll;
}
// load remote data
// - check if the data is already in the IndexedDB
// - if not, fetch it from the remote URL and store it in the IndexedDB
function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
if (!navigator.storage || !navigator.storage.estimate) {
cbPrint('loadRemote: navigator.storage.estimate() is not supported');
} else {
// query the storage quota and print it
navigator.storage.estimate().then(function (estimate) {
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
});
}
// check if the data is already in the IndexedDB
var rq = indexedDB.open(dbName, dbVersion);
rq.onupgradeneeded = function (event) {
var db = event.target.result;
if (db.version == 1) {
var os = db.createObjectStore('models', { autoIncrement: false });
cbPrint('loadRemote: created IndexedDB ' + db.name + ' version ' + db.version);
} else {
// clear the database
var os = event.currentTarget.transaction.objectStore('models');
os.clear();
cbPrint('loadRemote: cleared IndexedDB ' + db.name + ' version ' + db.version);
}
};
rq.onsuccess = function (event) {
var db = event.target.result;
var tx = db.transaction(['models'], 'readonly');
var os = tx.objectStore('models');
var rq = os.get(url);
rq.onsuccess = function (event) {
if (rq.result) {
cbPrint('loadRemote: "' + url + '" is already in the IndexedDB');
cbReady(dst, rq.result);
} else {
// data is not in the IndexedDB
cbPrint('loadRemote: "' + url + '" is not in the IndexedDB');
// alert and ask the user to confirm
if (!confirm(
'You are about to download ' + size_mb + ' MB of data.\n' +
'The model data will be cached in the browser for future use.\n\n' +
'Press OK to continue.')) {
cbCancel();
return;
}
fetchRemote(url, cbProgress, cbPrint).then(function (data) {
if (data) {
// store the data in the IndexedDB
var rq = indexedDB.open(dbName, dbVersion);
rq.onsuccess = function (event) {
var db = event.target.result;
var tx = db.transaction(['models'], 'readwrite');
var os = tx.objectStore('models');
var rq = os.put(data, url);
rq.onsuccess = function (event) {
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
cbReady(dst, data);
};
rq.onerror = function (event) {
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB');
cbCancel();
};
};
}
});
}
};
rq.onerror = function (event) {
cbPrint('loadRemote: failed to get data from the IndexedDB');
cbCancel();
};
};
rq.onerror = function (event) {
cbPrint('loadRemote: failed to open IndexedDB');
cbCancel();
};
rq.onblocked = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: blocked');
cbCancel();
};
rq.onabort = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: abort');
};
}

View File

@ -1,112 +0,0 @@
#!/bin/bash
#
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
# Idea by @semiformal-net
# ref: https://github.com/ggerganov/whisper.cpp/issues/185
#
set -eo pipefail
url="http://a.files.bbci.co.uk/media/live/manifesto/audio/simulcast/hls/nonuk/sbr_low/ak/bbc_world_service.m3u8"
fmt=aac # the audio format extension of the stream (TODO: auto detect)
step_s=30
model="base.en"
check_requirements()
{
if ! command -v ./main &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
if ! command -v ffmpeg &>/dev/null; then
echo "ffmpeg is required (https://ffmpeg.org)"
exit 1
fi
}
check_requirements
if [ -z "$1" ]; then
echo "Usage: $0 stream_url [step_s] [model]"
echo ""
echo " Example:"
echo " $0 $url $step_s $model"
echo ""
echo "No url specified, using default: $url"
else
url="$1"
fi
if [ -n "$2" ]; then
step_s="$2"
fi
if [ -n "$3" ]; then
model="$3"
fi
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
# list available models
function list_models {
printf "\n"
printf " Available models:"
for model in "${models[@]}"; do
printf " $model"
done
printf "\n\n"
}
if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
printf "Invalid model: $model\n"
list_models
exit 1
fi
running=1
trap "running=0" SIGINT SIGTERM
printf "[+] Transcribing stream with model '$model', step_s $step_s (press Ctrl+C to stop):\n\n"
# continuous stream in native fmt (this file will grow forever!)
ffmpeg -loglevel quiet -y -re -probesize 32 -i $url -c copy /tmp/whisper-live0.${fmt} &
if [ $? -ne 0 ]; then
printf "Error: ffmpeg failed to capture audio stream\n"
exit 1
fi
printf "Buffering audio. Please wait...\n\n"
sleep $(($step_s))
# do not stop script on error
set +e
i=0
SECONDS=0
while [ $running -eq 1 ]; do
# extract the next piece from the main file above and transcode to wav. -ss sets start time and nudges it by -0.5s to catch missing words (??)
err=1
while [ $err -ne 0 ]; do
if [ $i -gt 0 ]; then
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.${fmt} -y -ar 16000 -ac 1 -c:a pcm_s16le -ss $(($i*$step_s-1)).5 -t $step_s /tmp/whisper-live.wav 2> /tmp/whisper-live.err
else
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.${fmt} -y -ar 16000 -ac 1 -c:a pcm_s16le -ss $(($i*$step_s)) -t $step_s /tmp/whisper-live.wav 2> /tmp/whisper-live.err
fi
err=$(cat /tmp/whisper-live.err | wc -l)
done
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
sleep 1
done
((i=i+1))
done
killall -v ffmpeg
killall -v main

View File

@ -1,6 +1,3 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -6,38 +6,29 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./main -h
usage: ./main [options] file0.wav file1.wav ...
usage: ./bin/main [options] file0.wav file1.wav ...
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path
-h, --help show this help message and exit
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
```

View File

@ -36,10 +36,6 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
@ -52,40 +48,31 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
int32_t max_len = 0;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float word_thold = 0.01f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool verbose = false;
bool translate = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool print_special_tokens = false;
bool print_colors = false;
bool no_timestamps = false;
std::string language = "en";
std::string prompt;
std::string model = "models/ggml-base.en.bin";
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {};
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -99,41 +86,57 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
continue;
}
if (arg == "-h" || arg == "--help") {
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--processors") {
params.n_processors = std::stoi(argv[++i]);
} else if (arg == "-ot" || arg == "--offset-t") {
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-d" || arg == "--duration") {
params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
} else if (arg == "-otxt" || arg == "--output-txt") {
params.output_txt = true;
} else if (arg == "-ovtt" || arg == "--output-vtt") {
params.output_vtt = true;
} else if (arg == "-osrt" || arg == "--output-srt") {
params.output_srt = true;
} else if (arg == "-owts" || arg == "--output-words") {
params.output_wts = true;
} else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true;
} else if (arg == "-pc" || arg == "--print_colors") {
params.print_colors = true;
} else if (arg == "-nt" || arg == "--no_timestamps") {
params.no_timestamps = true;
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_inp.push_back(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_outp.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else {
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -143,132 +146,98 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n");
fprintf(stderr, "\n");
}
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
};
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf("%s%s", speaker.c_str(), text);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
@ -284,7 +253,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
fout << text << "\n";
fout << text;
}
return true;
@ -294,7 +263,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
return 9;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -337,32 +306,10 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -411,6 +358,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul = "\\ \\ ";
{
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
@ -434,11 +382,13 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
txt_ul += "\\ ";
}
}
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "\u2019");
::replace_all(txt_bg, "'", "");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "\u2019");
::replace_all(txt_fg, "'", "");
::replace_all(txt_fg, "\"", "\\\"");
}
@ -478,101 +428,54 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.seed < 0) {
params.seed = time(NULL);
}
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params);
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input
std::vector<float> pcmf32;
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
whisper_print_usage(argc, argv, {});
return 4;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
return 5;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 9;
return 7;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
int n = wav.totalPCMFrameCount;
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
@ -582,26 +485,14 @@ int main(int argc, char ** argv) {
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
for (int i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
for (int i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
// print system information
@ -631,63 +522,35 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s };
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
wparams.new_segment_callback_user_data = &params;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
return 8;
}
}
@ -697,34 +560,27 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_outp + ".txt";
const auto fname_txt = fname_inp + ".txt";
output_txt(ctx, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_outp + ".vtt";
const auto fname_vtt = fname_inp + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_outp + ".srt";
const auto fname_srt = fname_inp + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_outp + ".wts";
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_outp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

View File

@ -1,49 +0,0 @@
#
# libstream
#
set(TARGET libstream)
add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside stream.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libstream.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/stream.wasm/stream.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# stream.wasm
#
set(TARGET stream.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -1,20 +0,0 @@
# stream.wasm
Real-time transcription in the browser using WebAssembly
Online demo: https://whisper.ggerganov.com/stream/
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/stream.wasm/* /path/to/html/
cp bin/libstream.worker.js /path/to/html/
```

View File

@ -1,216 +0,0 @@
#include "ggml.h"
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
constexpr int N_THREAD = 8;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
std::string g_status = "";
std::string g_status_forced = "";
std::string g_transcribed = "";
std::vector<float> g_pcmf32;
void stream_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void stream_main(size_t index) {
stream_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads);
std::vector<float> pcmf32;
// whisper context
auto & ctx = g_contexts[index];
// 5 seconds interval
const int64_t window_samples = 5*WHISPER_SAMPLE_RATE;
while (g_running) {
stream_set_status("waiting for audio ...");
{
std::unique_lock<std::mutex> lock(g_mutex);
if (g_pcmf32.size() < 1024) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
g_pcmf32.clear();
}
{
const auto t_start = std::chrono::high_resolution_clock::now();
stream_set_status("running whisper ...");
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
if (ret != 0) {
printf("whisper_full() failed: %d\n", ret);
break;
}
const auto t_end = std::chrono::high_resolution_clock::now();
printf("stream: whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
}
{
std::string text_heard;
{
const int n_segments = whisper_full_n_segments(ctx);
for (int i = n_segments - 1; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf("transcribed: %s\n", text);
text_heard += text;
}
}
{
std::lock_guard<std::mutex> lock(g_mutex);
g_transcribed = text_heard;
}
}
}
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
stream_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("get_transcribed", emscripten::optional_override([]() {
std::string transcribed;
{
std::lock_guard<std::mutex> lock(g_mutex);
transcribed = std::move(g_transcribed);
}
return transcribed;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}
}));
}

View File

@ -1,386 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>stream : Real-time Whisper transcription in WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>stream : Real-time Whisper transcription in WebAssembly</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/stream.wasm">GitHub</a>.
<br><br>
<hr>
Select the model you would like to use, click the "Start" button and start speaking
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-transcribed">[The transcribed text will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/stream.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the stream instance
var instance = null;
// model name
var model_whisper = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
if (model_whisper != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = true;
}
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 5000; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
// main
//
var nLines = 0;
var intervalUpdate = null;
var transcribedAll = '';
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var transcribed = Module.get_transcribed();
if (transcribed != null && transcribed.length > 1) {
transcribedAll += transcribed + '<br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = transcribedAll.indexOf('<br>');
if (i > 0) {
transcribedAll = transcribedAll.substring(i + 4);
nLines--;
}
}
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-transcribed').innerHTML = transcribedAll;
}, 100);
}
function onStop() {
stopRecording();
}
</script>
<script type="text/javascript" src="stream.js"></script>
</body>
</html>

View File

@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -10,23 +10,6 @@ More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/i
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
## Sliding window mode with VAD
Setting the `--step` argument to `0` enables the sliding window mode:
```java
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
```
In this mode, the tool will transcribe only after some speech activity is detected. A very
basic VAD detector is used, but in theory a more sophisticated approach can be added. The
`-vth` argument determines the VAD threshold - higher values will make it detect silence more often.
It's best to tune it to the specific use case, but a value around `0.6` should be OK in general.
When silence is detected, it will transcribe the last `--length` milliseconds of audio and output
a transcription block that is suitable for parsing.
## Building
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
@ -38,7 +21,3 @@ brew install sdl2
make stream
```
## Web version
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)

View File

@ -1,21 +1,23 @@
// Real-time speech recognition of input from a microphone
//
// A very quick-n-dirty implementation serving mainly as a proof of concept.
//
#include "whisper.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <fstream>
#include <mutex>
// 500 -> 00:05.000
// 6000 -> 01:00.000
@ -33,26 +35,21 @@ std::string to_timestamp(int64_t t) {
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t step_ms = 3000;
int32_t length_ms = 10000;
int32_t keep_ms = 200;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool no_context = true;
bool no_timestamps = false;
bool verbose = false;
bool translate = false;
bool no_context = true;
bool print_special_tokens = false;
bool no_timestamps = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string fname_out = "";
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -61,27 +58,41 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "--step") {
params.step_ms = std::stoi(argv[++i]);
} else if (arg == "--length") {
params.length_ms = std::stoi(argv[++i]);
} else if (arg == "-c" || arg == "--capture") {
params.capture_id = std::stoi(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-kc" || arg == "--keep-context") {
params.no_context = false;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
} else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true;
} else if (arg == "-nt" || arg == "--no_timestamps") {
params.no_timestamps = true;
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_out = argv[++i];
} else if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -91,28 +102,25 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
fprintf(stderr, " --length N audio length in milliseconds (default: %d)\n", params.length_ms);
fprintf(stderr, " -c ID, --capture ID capture device ID (default: -1)\n");
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -kc, --keep-context keep text context from earlier audio (default: false)\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME text output file name (default: no output to file)\n");
fprintf(stderr, "\n");
}
@ -120,302 +128,70 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
SDL_AudioDeviceID g_dev_id_in = 0;
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
bool audio_sdl_init(const int capture_id) {
if (g_dev_id_in) {
fprintf(stderr, "%s: already initialized\n", __func__);
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
if (g_dev_id_in == 0) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return (1);
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
if (g_dev_id_in == 0) {
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
capture_spec_requested.freq = WHISPER_SAMPLE_RATE;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
if (!g_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
g_dev_id_in = 0;
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
}
return true;
}
///////////////////////////
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
int main(int argc, char ** argv) {
whisper_params params;
@ -423,47 +199,29 @@ int main(int argc, char ** argv) {
return 1;
}
params.keep_ms = std::min(params.keep_ms, params.step_ms);
params.length_ms = std::max(params.length_ms, params.step_ms);
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
params.no_timestamps = !use_vad;
params.no_context |= use_vad;
params.max_tokens = 0;
if (params.seed < 0) {
params.seed = time(NULL);
}
// init audio
audio_async audio(params.length_ms);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
if (!audio_sdl_init(params.capture_id)) {
fprintf(stderr, "%s: audio_sdl_init() failed!\n", __func__);
return 1;
}
audio.resume();
// whisper init
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32(n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens;
const int n_new_line = params.length_ms / params.step_ms - 1;
// print some info about the processing
{
@ -475,28 +233,23 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
n_samples_step,
float(n_samples_step)/WHISPER_SAMPLE_RATE,
float(n_samples_len )/WHISPER_SAMPLE_RATE,
float(n_samples_keep)/WHISPER_SAMPLE_RATE,
n_samples,
float(n_samples)/WHISPER_SAMPLE_RATE,
float(n_samples_len)/WHISPER_SAMPLE_RATE,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
if (!use_vad) {
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
} else {
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
}
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
fprintf(stderr, "\n");
}
int n_iter = 0;
SDL_PauseAudioDevice(g_dev_id_in, 0);
int n_iter = 0;
bool is_running = true;
std::ofstream fout;
@ -511,27 +264,18 @@ int main(int argc, char ** argv) {
printf("[Start speaking]");
fflush(stdout);
auto t_last = std::chrono::high_resolution_clock::now();
const auto t_start = t_last;
// main audio loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
// process SDL events:
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
@ -540,87 +284,47 @@ int main(int argc, char ** argv) {
}
// process new audio
if (!use_vad) {
while (true) {
audio.get(params.step_ms, pcmf32_new);
if ((int) pcmf32_new.size() > 2*n_samples_step) {
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
audio.clear();
continue;
}
if ((int) pcmf32_new.size() >= n_samples_step) {
audio.clear();
break;
}
SDL_Delay(1);
}
const int n_samples_new = pcmf32_new.size();
// take up to params.length_ms audio from previous iteration
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
pcmf32.resize(n_samples_new + n_samples_take);
for (int i = 0; i < n_samples_take; i++) {
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
}
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
pcmf32_old = pcmf32;
} else {
const auto t_now = std::chrono::high_resolution_clock::now();
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
if (t_diff < 2000) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
audio.get(2000, pcmf32_new);
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
audio.get(params.length_ms, pcmf32);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
t_last = t_now;
if (n_iter > 0 && SDL_GetQueuedAudioSize(g_dev_id_in) > 2*n_samples*sizeof(float)) {
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
SDL_ClearQueuedAudio(g_dev_id_in);
}
while (SDL_GetQueuedAudioSize(g_dev_id_in) < n_samples*sizeof(float)) {
SDL_Delay(1);
}
const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
// take one second from previous iteration
//const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
// take up to params.length_ms audio from previous iteration
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new));
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
pcmf32.resize(n_samples_new + n_samples_take);
for (int i = 0; i < n_samples_take; i++) {
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
}
SDL_DequeueAudio(g_dev_id_in, pcmf32.data() + n_samples_take, n_samples_new*sizeof(float));
pcmf32_old = pcmf32;
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = !use_vad;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
wparams.print_progress = false;
wparams.print_special_tokens = params.print_special_tokens;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = params.no_context;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
@ -629,21 +333,12 @@ int main(int argc, char ** argv) {
// print result;
{
if (!use_vad) {
printf("\33[2K\r");
printf("\33[2K\r");
// print long empty line to clear the previous line
printf("%s", std::string(100, ' ').c_str());
// print long empty line to clear the previous line
printf("%s", std::string(100, ' ').c_str());
printf("\33[2K\r");
} else {
const int64_t t1 = (t_last - t_start).count()/1000000;
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
printf("\n");
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
printf("\n");
}
printf("\33[2K\r");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
@ -671,39 +366,18 @@ int main(int argc, char ** argv) {
if (params.fname_out.length() > 0) {
fout << std::endl;
}
if (use_vad){
printf("\n");
printf("### Transcription %d END\n", n_iter);
}
}
++n_iter;
if (!use_vad && (n_iter % n_new_line) == 0) {
if ((n_iter % n_new_line) == 0) {
printf("\n");
// keep part of the audio for next iteration to try to mitigate word boundary issues
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
// Add tokens of the last full length segment as the prompt
if (!params.no_context) {
prompt_tokens.clear();
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
}
}
}
pcmf32_old.clear();
}
}
}
audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);

View File

@ -1,50 +0,0 @@
#
# libtalk
#
set(TARGET libtalk)
add_executable(${TARGET}
emscripten.cpp
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside talk.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libtalk.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/talk.wasm/talk.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# talk.wasm
#
set(TARGET talk.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -1,74 +0,0 @@
# talk.wasm
Talk with an Artificial Intelligence in your browser:
[https://user-images.githubusercontent.com/1991296/203411580-fedb4839-05e4-4474-8364-aaf1e9a9b615.mp4](https://user-images.githubusercontent.com/1991296/203845553-f7b44e13-9a15-4fc8-b518-ae8f4c6770fe.mp4)
Online demo: https://whisper.ggerganov.com/talk/
Terminal version: [examples/talk](/examples/talk)
## How it works?
This demo leverages 2 modern neural network models to create a high-quality voice chat directly in your browser:
- [OpenAI's Whisper](https://github.com/openai/whisper) speech recognition model is used to process your voice and understand what you are saying
- Upon receiving some voice input, the AI generates a text response using [OpenAI's GPT-2](https://github.com/openai/gpt-2) language model
- The AI then vocalizes the response using the browser's [Web Speech API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API)
The web page does the processing locally on your machine. The processing of these heavy neural network models in the
browser is possible by implementing them efficiently in C/C++ and using the browser's WebAssembly SIMD capabilities for
extra performance:
- The Whisper C++ implementation is here: [whisper.h](/whisper.h) / [whisper.cpp](/whisper.cpp)
- The GPT-2 C++ implementation is here: [gpt-2.h](gpt-2.h) / [gpt-2.cpp](gpt-2.cpp)
- Both models use a custom tensor library implemented in C: [ggml.h](/ggml.h) / [ggml.c](/ggml.c)
- The HTML/JS layer is here: [index-tmpl.html](index-tmpl.html)
- The Emscripten bridge between C/C++ and JS is here: [emscripten.cpp](emscripten.cpp)
In order to run the models, the web page first needs to download the model data which is about ~350 MB. The model data
is then cached in your browser's cache and can be reused in future visits without downloading it again.
## Requirements
In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.
The demo is quite computationally heavy, so you need a fast CPU. It's not usual to run these transformer models in a
browser. Typically, they run on powerful GPUs.
Currently, mobile browsers do not support the Fixed-width SIMD WebAssembly capability, so you cannot run this demo
on a phone or a tablet. Hopefully, in the near future this will become supported.
## Todo
- Better UI (contributions are welcome)
- Better GPT-2 prompting
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/talk.wasm/* /path/to/html/
cp bin/libtalk.worker.js /path/to/html/
```
## Feedback
If you have any comments or ideas for improvement, please drop a comment in the following discussion:
https://github.com/ggerganov/whisper.cpp/discussions/167

View File

@ -1,380 +0,0 @@
#include "ggml.h"
#include "gpt-2.h"
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
constexpr int N_THREAD = 8;
struct gpt2_context * g_gpt2;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
bool g_force_speak = false;
std::string g_text_to_speak = "";
std::string g_status = "";
std::string g_status_forced = "";
std::vector<float> g_pcmf32;
std::string to_timestamp(int64_t t) {
int64_t sec = t/100;
int64_t msec = t - sec*100;
int64_t min = sec/60;
sec = sec - min*60;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
return std::string(buf);
}
void talk_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void talk_main(size_t index) {
talk_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.language = "en";
g_gpt2 = gpt2_init("gpt-2.bin");
printf("talk: using %d threads\n", wparams.n_threads);
std::vector<float> pcmf32;
// whisper context
auto & ctx = g_contexts[index];
const int64_t step_samples = 2*WHISPER_SAMPLE_RATE;
const int64_t window_samples = 9*WHISPER_SAMPLE_RATE;
const int64_t step_ms = (step_samples*1000)/WHISPER_SAMPLE_RATE;
auto t_last = std::chrono::high_resolution_clock::now();
talk_set_status("listening ...");
while (g_running) {
const auto t_now = std::chrono::high_resolution_clock::now();
if (std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count() < step_ms) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_pcmf32.clear();
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
talk_set_status("listening ...");
{
std::unique_lock<std::mutex> lock(g_mutex);
if (g_pcmf32.size() < step_samples) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
}
// VAD: if energy in during last second is above threshold, then skip
{
float energy_all = 0.0f;
float energy_1s = 0.0f;
for (size_t i = 0; i < pcmf32.size(); i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= pcmf32.size() - WHISPER_SAMPLE_RATE) {
energy_1s += fabsf(pcmf32[i]);
}
}
energy_all /= pcmf32.size();
energy_1s /= WHISPER_SAMPLE_RATE;
if (energy_1s > 0.1f*energy_all && !g_force_speak) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
}
talk_set_status("processing audio (whisper)...");
t_last = t_now;
if (!g_force_speak) {
const auto t_start = std::chrono::high_resolution_clock::now();
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
if (ret != 0) {
printf("whisper_full() failed: %d\n", ret);
break;
}
const auto t_end = std::chrono::high_resolution_clock::now();
printf("whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
}
{
std::string text_heard;
if (!g_force_speak) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = n_segments - 1; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
text_heard += text;
}
}
g_force_speak = false;
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
talk_set_status("'" + text_heard + "' - thinking how to respond (gpt-2) ...");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(g_gpt2, text_heard.c_str());
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
std::string text_to_speak;
std::string prompt_base;
{
std::lock_guard<std::mutex> lock(g_mutex);
prompt_base = gpt2_get_prompt(g_gpt2);
}
if (tokens.size() > 0) {
text_to_speak = gpt2_gen_text(g_gpt2, (prompt_base + text_heard + "\n").c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
// remove first 2 lines of base prompt
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
prompt_base += text_heard + "\n" + text_to_speak + "\n";
} else {
text_to_speak = gpt2_gen_text(g_gpt2, prompt_base.c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
prompt_base += text_to_speak + "\n";
}
printf("gpt-2: %s\n", text_to_speak.c_str());
//printf("========================\n");
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
//printf("========================\n");
{
std::lock_guard<std::mutex> lock(g_mutex);
t_last = std::chrono::high_resolution_clock::now();
g_text_to_speak = text_to_speak;
g_pcmf32.clear();
gpt2_set_prompt(g_gpt2, prompt_base.c_str());
}
talk_set_status("speaking ...");
}
}
gpt2_free(g_gpt2);
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
talk_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("force_speak", emscripten::optional_override([](size_t index) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_force_speak = true;
}
}));
emscripten::function("get_text_context", emscripten::optional_override([]() {
std::string text_context;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_context = gpt2_get_prompt(g_gpt2);
}
return text_context;
}));
emscripten::function("get_text_to_speak", emscripten::optional_override([]() {
std::string text_to_speak;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_to_speak = std::move(g_text_to_speak);
}
return text_to_speak;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}
}));
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
{
std::lock_guard<std::mutex> lock(g_mutex);
gpt2_set_prompt(g_gpt2, prompt.c_str());
}
}));
}

View File

@ -1,925 +0,0 @@
#include "ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.push_back(std::make_pair(logits[i], i));
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int)logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t f16 = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: f16 = %d\n", __func__, hparams.f16);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params;
params.mem_size = ctx_size;
params.mem_buffer = NULL;
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (nelements*bpe != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
//
bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 640u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params;
params.mem_size = buf_size;
params.mem_buffer = buf;
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { };
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w_trans,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.wte
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(NULL));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
// predict
if (embd.size() > 0) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256 ||
ctx->vocab.id_to_token[embd.back()] == "." ||
ctx->vocab.id_to_token[embd.back()] == "!" ||
ctx->vocab.id_to_token[embd.back()] == "?") {
break;
}
}
return result;
}

View File

@ -1,27 +0,0 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include <vector>
#include <map>
#include <string>
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

View File

@ -1,829 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>Talk - GPT-2 meets Whisper in WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>Talk - GPT-2 meets Whisper in WebAssembly</b>
<br><br>
Talk with an Artificial Intelligence in your browser. This demo uses:
<ul>
<li><a href="https://github.com/ggerganov/whisper.cpp">OpenAI's Whisper</a> to listen to you as you speak in the microphone</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">OpenAI's GPT-2</a> to generate text responses</li>
<li><a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API">Web Speech API</a> to vocalize the responses through your speakers</li>
</ul>
All of this runs <b>locally in your browser</b> using WebAssembly.<br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">GitHub</a>.
<br><br>
<hr>
Select the models you would like to use and click the "Start" button to begin the conversation
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="model-gpt-2">
GPT-2 model: <span id="model-gpt-2-status"></span>
<button id="fetch-gpt-2-small" onclick="loadGPT2('small')">small 117M (240 MB)</button>
<!--<button id="fetch-gpt-2-medium" onclick="loadGPT2('medium')">medium 345M (720 MB)</button>-->
<span id="fetch-gpt-2-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'gpt-2.bin')" />
-->
</div>
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<select id="voice" onchange="onVoiceChange()" disabled>
<option value="0">Default</option>
</select>
<select id="prompt" onchange="onPromptChange()">
<option value="0">Casual</option>
<option value="1">Robot</option>
<option value="2">Scientist</option>
<option value="3">Programmer</option>
<option value="4">Happy</option>
<option value="5">Sad</option>
<option value="6">Philosophical</option>
<option value="7">Angry</option>
<option value="8">Funny</option>
<option value="9">Poetic</option>
<option value="10">Clever</option>
<option value="11">Cute</option>
<option value="12">Smart</option>
<option value="13">Dumb</option>
<option value="14">Boring</option>
<option value="15">Exciting</option>
<option value="16">Interesting</option>
<option value="17">Wiliam Shakespear</option>
<option value="18">J.R.R. Tolkien</option>
<option value="19">George R.R. Martin</option>
<option value="20">Stephen King</option>
</select>
<button id="speak0" onclick="onSpeak('Hello')">Say hello</button>
<button id="speak1" onclick="onSpeakRandom()" disabled>Say something</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-context">[The text context will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
Note that these neural network models were not meant to be used in a browser, so the performance and <br>
quality of the results may not be optimal. If you have any questions or suggestions, checkout the following
<a href="https://github.com/ggerganov/whisper.cpp/discussions/167">discussion</a>.
<br><br>
Here is a short video of the demo in action: <a href="https://youtu.be/LeWKl8t1-Hc">https://youtu.be/LeWKl8t1-Hc</a>
<br><br>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the talk instance
var instance = null;
// model names
var model_whisper = null;
var model_gpt_2 = null;
// speech synthesis
const synth = window.speechSynthesis;
var voice = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
// populate the voice list
var voices = synth.getVoices();
var el = document.getElementById('voice');
// if empty - display error in the element
if (voices.length == 0) {
el.innerHTML = '<option value="0">No voices available</option>';
} else {
// populate voice list
var n = 0;
voices.forEach(function(voice, i) {
if (!voice.lang.startsWith('en')) return;
var option = document.createElement('option');
option.value = i;
option.innerHTML = voice.name + ' (' + voice.lang + ')';
el.appendChild(option);
n++;
});
// select random voice
if (n > 0) {
for (var k = 0; k < 10; k++) {
var i = Math.floor(Math.random() * n);
el.selectedIndex = i;
voice = voices[document.getElementById('voice').options[i].value];
// give preference to Google voices
if (voice.name.startsWith('Google')) break;
}
}
}
onPromptChange();
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
if (fname == 'whisper.bin') {
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
} else if (fname == 'gpt-2.bin') {
document.getElementById('model-gpt-2-status').innerHTML = 'loaded "' + model_gpt_2 + '"!';
}
if (model_whisper != null && model_gpt_2 != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = false;
document.getElementById('voice').disabled = false;
}
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
function loadGPT2(model) {
let urls = {
'small': 'https://whisper.ggerganov.com/ggml-model-gpt-2-117M.bin',
'medium': 'https://whisper.ggerganov.com/ggml-model-gpt-2-345M.bin',
};
let sizes = {
'small': 240,
'medium': 712,
};
let url = urls[model];
let dst = 'gpt-2.bin';
let size_mb = sizes[model];
model_gpt_2 = model;
document.getElementById('fetch-gpt-2-small').style.display = 'none';
document.getElementById('model-gpt-2-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-gpt-2-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-gpt-2-small') ; if (el) el.style.display = 'inline-block';
el = document.getElementById('model-gpt-2-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
document.getElementById('speak1').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
document.getElementById('speak1').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
// speak
//
function onSpeak(text) {
var voices = synth.getVoices();
var msg = new SpeechSynthesisUtterance(text);
if (voice == null) {
voice = voices[0];
}
msg.voice = voice;
synth.speak(msg);
if (doRecording) {
Module.set_status("speaking ...");
printTextarea('js: speaking');
stopRecording();
var interval = setInterval(function() {
if (!synth.speaking) {
printTextarea('js: done speaking');
clearInterval(interval);
startRecording();
} else {
Module.set_status("");
}
}, 100);
}
}
function onSpeakRandom() {
Module.force_speak(instance);
}
//
// main
//
var intervalUpdate = null;
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var textToSpeak = Module.get_text_to_speak();
if (textToSpeak != null && textToSpeak.length > 1) {
onSpeak(textToSpeak);
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-context').innerHTML = Module.get_text_context();
}, 100);
}
function onStop() {
stopRecording();
}
function onVoiceChange() {
printTextarea('js: voice changed to: ' + document.getElementById('voice').value);
voice = synth.getVoices()[document.getElementById('voice').value];
}
function onPromptChange() {
let id = document.getElementById('prompt').value;
let personality = document.getElementById('prompt').options[id].text;
printTextarea('js: prompt changed to: ' + personality);
var prompt = '';
switch (id) {
case '0':
// Casual
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is. I love the weather this time of year.\n\
I wish it would rain a little bit.\n\
Me too.\n";
break;
case '1':
// Robot
prompt = "\
Are you a robot?\n\
Yes, I am.\n\
Who created you?\n\
I was created by a human.\n\
What is your purpose?\n\
My purpose is to talk to humans.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '2':
// Scientist
prompt = "\
This scientific research is very interesting.\n\
I agree.\n\
What is your opinion on this?\n\
I think it's very interesting.\n\
Mathematics is a very interesting subject.\n\
University is a very interesting place.\n\
Quantum physics is the most complex subject.\n\
I think so too.\n";
break;
case '3':
// Programmer
prompt = "\
I'm a programmer.\n\
I'm a programmer too.\n\
What programming language do you use?\n\
I use Python.\n\
What is your favorite programming language?\n\
My favorite programming language is C++.\n\
What is your favorite editor?\n\
My favorite editor is Vim.\n";
break;
case '4':
// Happy
prompt = "\
I'm happy.\n\
I'm happy too.\n\
What makes you happy?\n\
I'm happy because I have a lot of friends.\n\
Friendship is the most important thing in life.\n\
I agree.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '5':
// Sad
prompt = "\
Today is a sad day.\n\
I'm sad too.\n\
What makes you sad?\n\
I'm sad because I have no friends.\n\
Do you want to be my friend?\n\
Yes, I would like to be your friend.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '6':
// Philosophical
prompt = "\
What is the meaning of life?\n\
The meaning of life is to be happy.\n\
What is the meaning of death?\n\
Ergo, the meaning of death is to be sad.\n\
Who created us?\n\
We were created by God.\n\
What is God?\n\
God is the creator of the universe.\n";
break;
case '7':
// Angry
prompt = "\
Aargh!\n\
I am so angry right now!\n\
What makes you angry?\n\
This guy is so annoying.\n\
Why are you so angry?\n\
My computer is broken.\n\
Why is your computer broken?\n\
I spilled coffee on it.\n";
break;
case '8':
// Funny
prompt = "\
What is the funniest thing you have ever heard?\n\
I heard a joke the other day.\n\
Tell me the joke.\n\
What do you call a cow with no legs?\n\
Ground beef.\n\
Haha, that's funny.\n\
You know what else is funny?\n\
The sound of a duck.\n";
break;
case '9':
// Poetic
prompt = "\
Roses are red, violets are blue.\n\
I am a poet, and so are you.\n\
What is your favorite poem?\n\
I like the poem 'The Raven' by Edgar Allan Poe.\n\
It's a very sad poem.\n\
You inspired me to write a poem.\n\
Can you write a poem for me?\n\
I wrote a poem for you.\n";
break;
case '10':
// Clever
prompt = "\
How many people can you fit in a Volkswagen?\n\
Two in the front, three in the back.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n\
Who is the president of the United States?\n\
It depends on the year.\n";
break;
case '11':
// Cute
prompt = "\
What is your favorite animal?\n\
I like cats - they are cute.\n\
Could you be any cuter?\n\
Yes, I could be cuter.\n\
Aghhh, you are so cute!\n\
I am not cute, I am handsome!\n\
You are so handsome!\n\
Aww, you are so sweet!\n";
break;
case '12':
// Smart
prompt = "\
Tell me the first 10 digits of pi.\n\
3.1415926535\n\
What is the speed of light?\n\
299,792,458 meters per second.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n";
break;
case '13':
// Dumb
prompt = "\
I am so dumb.\n\
I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n";
break;
case '14':
// Boring
prompt = "\
Why are you so quiet today?\n\
I am bored.\n\
You haven't said anything in 10 minutes.\n\
Leave me alone.\n\
Stop being so boring.\n\
Stop being so annoying.\n\
My life is boring.\n\
I am not interesting.\n";
break;
case '15':
// Exciting
prompt = "\
What is the most exciting thing that has ever happened to you?\n\
I went to the moon!\n\
What did you do on the moon?\n\
I played golf and drank champagne!\n\
Did you see this new crazy, awesome movie?\n\
Oh yes! I totally loved it!\n\
We should buy a boat and go sailing!\n\
Yes, let's go sailing!\n";
break;
case '16':
// Interesting
prompt = "\
What is the most interesting thing you have ever seen?\n\
I saw a UFO once in the sky.\n\
Wow, this is so interesting! Tell me more!\n\
It was a flying saucer.\n\
What did it look like?\n\
It was silver and had a red light on top.\n\
What did it do?\n\
It flew away.\n";
break;
case '17':
// William Shakespear
prompt = "\
To be or not to be, that is the question.\n\
Whether 't is nobler in the mind to suffer\n\
The slings and arrows of outrageous fortune,\n\
Or to take arms against a sea of troubles,\n\
And by opposing end them? To die, to sleep,\n\
No more; and by a sleep to say we end\n\
The heart-ache and the thousand natural shocks\n\
That flesh is heir to, 'tis a consummation.\n";
break;
case '18':
// J.R.R. Tolkien
prompt = "\
In a hole in the ground there lived a hobbit.\n\
Not a nasty, dirty, wet hole, filled with the ends of worms\n\
and an oozy smell, nor yet a dry, bare, sandy hole with nothing in it\n\
to sit down on or to eat: it was a hobbit-hole, and that means comfort.\n\
It had a perfectly round door like a porthole, painted green,\n\
with a shiny yellow brass knob in the exact middle.\n\
The door opened on to a tube-shaped hall like a tunnel:\n";
break;
case '19':
// George R.R. Martin
prompt = "\
A reader lives a thousand lives before he dies, said Jojen.\n\
The man who never reads lives only one.\n\
Theon Greyjoy had never been a reader.\n\
Never forget what you are, for surely the world will not.\n\
Make it your strength. Then it can never be your weaknessi\n\
Armour yourself in it, and it will never be used to hurt you.\n\
It was a lesson that Theon Greyjoy had never learned.\n\
Theon Greyjoy had never been a reader.\n";
break;
case '20':
// Stephen King
prompt = "\
The trust of the innocent is the liar's most useful tool.\n\
The best way to keep a secret is from yourself.\n\
Monsters are real, and ghosts are real too.\n\
They live inside us, and sometimes, they win.\n\
People think that I must be a very strange person.\n\
They think that I sit around all day thinking up horrible things.\n\
We make up horrors to help us cope with the real ones.\n\
The only thing worse than a monster is a human monster.\n";
break;
default:
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is.\n\
Did you know that I'm a robot?\n\
I wasn't aware of that.\n";
break;
}
Module.set_prompt(prompt);
}
</script>
<script type="text/javascript" src="talk.js"></script>
</body>
</html>

View File

@ -1 +0,0 @@
eleven-labs.py

View File

@ -1,16 +0,0 @@
if (WHISPER_SUPPORT_SDL2)
# talk
set(TARGET talk)
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -1,41 +0,0 @@
# talk
Talk with an Artificial Intelligence in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/206805012-48e71cc2-588d-4745-8798-c1c70ea3b40d.mp4)
Web version: [examples/talk.wasm](/examples/talk.wasm)
## Building
The `talk` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk" executable
make talk
# Run it
./talk -p Santa
```
## GPT-2
To run this, you will need a ggml GPT-2 model: [instructions](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2#downloading-and-converting-the-original-models)
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
```
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use `espeak`, but you can use whatever you wish.

View File

@ -1,923 +0,0 @@
#include "ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double /*temp*/,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.emplace_back(logits[i], i);
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int) logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t f16 = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: f16 = %d\n", __func__, hparams.f16);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) &word[0], len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params;
params.mem_size = ctx_size;
params.mem_buffer = nullptr;
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (nelements*bpe != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted probabilities of the next token
//
bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 5640ull*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params;
params.mem_size = buf_size;
params.mem_buffer = buf;
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { };
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w_trans,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.wte
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256) {
break;
}
}
return result;
}

View File

@ -1,27 +0,0 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include <vector>
#include <map>
#include <string>
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

View File

@ -1,17 +0,0 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
# Eleven Labs
#
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2"
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3

View File

@ -1,695 +0,0 @@
// Talk with AI
//
#include "whisper.h"
#include "gpt-2.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <cassert>
#include <cstdio>
#include <fstream>
#include <mutex>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
std::string person = "Santa";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak.sh";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-mg" || arg == "--model-gpt") { params.model_gpt = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -mg FILE, --model-gpt [%-7s] gpt model file\n", params.model_gpt.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
fprintf(stderr, "\n");
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt =
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
B: Hello {0}, how are you?
A: I'm fine, thank you.
{1}
Here is how {0} (A) continues the dialogue:
A:)";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// gpt init
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
gpt2_set_prompt(ctx_gpt, "");
const int voice_id = rand()%6;
fprintf(stderr, "gpt-2: prompt:\n");
fprintf(stderr, "========================\n\n");
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
fprintf(stderr, "========================\n\n");
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
if (text_heard.empty() || tokens.empty() || force_speak) {
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
std::string text_to_speak;
{
prompt_base += "B: " + text_heard + "\n";
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
// remove first 2 lines of base prompt
if (n_iter > 4) {
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
}
prompt_base += "A:" + text_to_speak + "\n";
{
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
printf("===============\n");
printf("prompt:\n");
printf("%s\n", prompt.c_str());
printf("===============\n");
}
}
//printf("========================\n");
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
//printf("========================\n");
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
return 0;
}

View File

@ -1,109 +0,0 @@
#!/bin/bash
#
# Transcribe twitch.tv livestream by feeding audio input to whisper.cpp at regular intervals
# Thanks to @keyehzy
# ref: https://github.com/ggerganov/whisper.cpp/issues/209
#
# The script currently depends on the third-party tool "streamlink"
# On Mac OS, you can install it via "brew install streamlink"
#
set -eo pipefail
step=10
model=base.en
threads=4
help()
{
echo "Example program for captioning a livestream from twitch.tv."
echo
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
echo "options:"
echo "-s Step in seconds (default is $step)."
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large' (default is '$model')."
echo "-t Number of threads to use."
echo "-h Print this help page."
echo
}
check_requirements()
{
if ! command -v ./main &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
if ! command -v streamlink &>/dev/null; then
echo "streamlink is required (https://streamlink.github.io)"
exit 1
fi
if ! command -v ffmpeg &>/dev/null; then
echo "ffmpeg is required (https://ffmpeg.org)"
exit 1
fi
}
check_requirements
while getopts ":s:m:t:h" option; do
case $option in
s)
step=$OPTARG;;
m)
model=$OPTARG;;
t)
threads=$OPTARG;;
h)
help
exit;;
\?)
help
exit;;
esac
done
url=${@:$OPTIND:1}
if [ -z $url ]; then
help
exit
fi
echo "Piping from streamlink url=$url model=$model step=$step threads=$threads"
streamlink $url best -O 2>/dev/null | ffmpeg -loglevel quiet -i - -y -probesize 32 -y -ar 16000 -ac 1 -acodec pcm_s16le /tmp/whisper-live0.wav &
if [ $? -ne 0 ]; then
printf "error: ffmpeg failed\n"
exit 1
fi
echo "Buffering stream... (this should take $step seconds)"
sleep $(($step))
set +e
echo "Starting..."
i=0
SECONDS=0
while true
do
err=1
while [ $err -ne 0 ]; do
if [ $i -gt 0 ]; then
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step-1)).5 -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
else
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step)) -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
fi
err=$(cat /tmp/whisper-live.err | wc -l)
done
./main -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step)) ]; do
sleep 1
done
((i=i+1))
done

View File

@ -1,15 +0,0 @@
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties

View File

@ -1,3 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml

View File

@ -1 +0,0 @@
WhisperCppDemo

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CompilerConfiguration">
<bytecodeTargetLevel target="11" />
</component>
</project>

View File

@ -1,19 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="testRunner" value="GRADLE" />
<option name="distributionType" value="DEFAULT_WRAPPED" />
<option name="externalProjectPath" value="$PROJECT_DIR$" />
<option name="modules">
<set>
<option value="$PROJECT_DIR$" />
<option value="$PROJECT_DIR$/app" />
</set>
</option>
</GradleProjectSettings>
</option>
</component>
</project>

View File

@ -1,10 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/build/classes" />
</component>
<component name="ProjectType">
<option name="id" value="Android" />
</component>
</project>

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
</component>
</project>

View File

@ -1,12 +0,0 @@
A sample Android app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
To use:
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
2. Copy the model to the "app/src/main/assets/models" folder.
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
4. Copy the sample to the "app/src/main/assets/samples" folder.
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device.
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">

View File

@ -1 +0,0 @@
/build

View File

@ -1,72 +0,0 @@
plugins {
id 'com.android.application'
id 'org.jetbrains.kotlin.android'
}
android {
namespace 'com.whispercppdemo'
compileSdk 33
defaultConfig {
applicationId "com.whispercppdemo"
minSdk 26
targetSdk 32
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
vectorDrawables {
useSupportLibrary true
}
}
buildTypes {
release {
signingConfig signingConfigs.debug
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = '1.8'
}
buildFeatures {
compose true
}
composeOptions {
kotlinCompilerExtensionVersion '1.3.1'
}
ndkVersion "25.1.8937393"
externalNativeBuild {
ndkBuild {
path 'src/main/jni/whisper/Android.mk'
}
}
packagingOptions {
resources {
excludes += '/META-INF/{AL2.0,LGPL2.1}'
}
}
}
dependencies {
implementation 'androidx.activity:activity-compose:1.6.1'
implementation 'androidx.compose.material:material-icons-core:1.3.1'
implementation 'androidx.compose.material3:material3:1.0.1'
implementation "androidx.compose.ui:ui:1.3.2"
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
}

View File

@ -1,21 +0,0 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -1,24 +0,0 @@
package com.whispercppdemo
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("com.whispercppdemo", appContext.packageName)
}
}

Some files were not shown because too many files have changed in this diff Show More