mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-24 17:15:19 +00:00
Compare commits
1 Commits
arghh
...
macros-cvt
Author | SHA1 | Date | |
---|---|---|---|
e0bd97f41f |
17
.github/workflows/bindings.yml
vendored
17
.github/workflows/bindings.yml
vendored
@ -1,17 +0,0 @@
|
||||
name: Bindings Tests
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- bindings/go/**
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '^1.19'
|
||||
- uses: actions/checkout@v1
|
||||
- run: |
|
||||
cd bindings/go
|
||||
make test
|
86
.github/workflows/build.yml
vendored
86
.github/workflows/build.yml
vendored
@ -119,59 +119,7 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
build: [RelWithDebInfo]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
@ -233,35 +181,5 @@ jobs:
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
15
.gitignore
vendored
15
.gitignore
vendored
@ -8,24 +8,17 @@ build/
|
||||
build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
build-static/
|
||||
build-sanitize-addr/
|
||||
build-sanitize-thread/
|
||||
|
||||
/main
|
||||
/stream
|
||||
/command
|
||||
/talk
|
||||
/bench
|
||||
|
||||
main
|
||||
stream
|
||||
command
|
||||
bench
|
||||
sync.sh
|
||||
libwhisper.a
|
||||
libwhisper.so
|
||||
compile_commands.json
|
||||
|
||||
examples/arm_neon.h
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/xcshareddata
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
|
||||
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
|
||||
|
||||
extra/bench-gg.txt
|
||||
|
@ -1,22 +1,19 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
project(whisper.cpp VERSION 1.0.0)
|
||||
|
||||
project(whisper.cpp VERSION 1.1.0)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
include(GitVars)
|
||||
include(BuildTypes)
|
||||
include(cmake/GitVars.cmake)
|
||||
include(cmake/BuildTypes.cmake)
|
||||
|
||||
# configure project version
|
||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||
configure_file(${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl ${CMAKE_SOURCE_DIR}/bindings/ios/Makefile @ONLY)
|
||||
endif()
|
||||
configure_file(${CMAKE_SOURCE_DIR}/bindings/javascript/package-tmpl.json ${CMAKE_SOURCE_DIR}/bindings/javascript/package.json @ONLY)
|
||||
else()
|
||||
set(WHISPER_STANDALONE OFF)
|
||||
endif()
|
||||
@ -53,7 +50,6 @@ if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
endif()
|
||||
@ -84,6 +80,9 @@ endif()
|
||||
|
||||
# dependencies
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# on APPLE - include Accelerate framework
|
||||
@ -130,13 +129,6 @@ if (WHISPER_ALL_WARNINGS)
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wpedantic \
|
||||
-Wcast-qual \
|
||||
")
|
||||
else()
|
||||
# todo : msvc
|
||||
@ -157,10 +149,10 @@ else()
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
# we require support for WASM SIMD 128-bit
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
|
||||
else()
|
||||
if(NOT WHISPER_NO_AVX)
|
||||
@ -169,10 +161,7 @@ else()
|
||||
if(NOT WHISPER_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@ -188,14 +177,10 @@ endif()
|
||||
set(TARGET whisper)
|
||||
|
||||
add_library(${TARGET}
|
||||
ggml.h
|
||||
ggml.c
|
||||
whisper.h
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
@ -218,10 +203,6 @@ if (BUILD_SHARED_LIBS)
|
||||
)
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-msimd128")
|
||||
endif()
|
||||
|
||||
target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
@ -229,7 +210,6 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
|
||||
#
|
||||
@ -242,11 +222,13 @@ add_subdirectory(bindings)
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
if (WHISPER_STANDALONE)
|
||||
if (WHISPER_BUILD_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
|
||||
if (WHISPER_BUILD_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
if (WHISPER_BUILD_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
endif ()
|
||||
|
69
Makefile
69
Makefile
@ -10,9 +10,6 @@ ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -30,8 +27,8 @@ endif
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
CFLAGS = -I. -O3 -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
|
||||
CFLAGS = -I. -O3 -std=c11
|
||||
CXXFLAGS = -I. -I./examples -O3 -std=c++11
|
||||
LDFLAGS =
|
||||
|
||||
# OS specific
|
||||
@ -48,21 +45,14 @@ ifeq ($(UNAME_S),FreeBSD)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
ifeq ($(UNAME_S),Haiku)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mf16c
|
||||
CFLAGS += -mfma -mf16c
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
@ -87,27 +77,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
|
||||
ifneq (,$(findstring avx2,$(AVX2_M)))
|
||||
CFLAGS += -mavx2
|
||||
endif
|
||||
FMA_M := $(shell sysinfo -cpu | grep "FMA ")
|
||||
ifneq (,$(findstring fma,$(FMA_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
else
|
||||
CFLAGS += -mfma -mf16c -mavx -mavx2
|
||||
endif
|
||||
@ -115,12 +84,6 @@ endif
|
||||
ifeq ($(UNAME_M),amd64)
|
||||
CFLAGS += -mavx -mavx2 -mfma -mf16c
|
||||
endif
|
||||
ifeq ($(UNAME_M),ppc64le)
|
||||
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
||||
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
||||
CFLAGS += -mpower9-vector
|
||||
endif
|
||||
endif
|
||||
ifndef WHISPER_NO_ACCELERATE
|
||||
# Mac M1 - include Accelerate framework
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -133,8 +96,8 @@ ifdef WHISPER_OPENBLAS
|
||||
LDFLAGS += -lopenblas
|
||||
endif
|
||||
ifdef WHISPER_GPROF
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
CFLAGS += -pg
|
||||
CXXFLAGS += -pg
|
||||
endif
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
endif
|
||||
@ -151,21 +114,6 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
#
|
||||
# Print build information
|
||||
#
|
||||
|
||||
$(info I whisper.cpp build info: )
|
||||
$(info I UNAME_S: $(UNAME_S))
|
||||
$(info I UNAME_P: $(UNAME_P))
|
||||
$(info I UNAME_M: $(UNAME_M))
|
||||
$(info I CFLAGS: $(CFLAGS))
|
||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
||||
$(info I LDFLAGS: $(LDFLAGS))
|
||||
$(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
default: main
|
||||
|
||||
#
|
||||
@ -185,7 +133,7 @@ libwhisper.so: ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
|
||||
rm -f *.o main stream command bench libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -203,9 +151,6 @@ stream: examples/stream/stream.cpp ggml.o whisper.o
|
||||
command: examples/command/command.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
bench: examples/bench/bench.cpp ggml.o whisper.o
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
||||
|
||||
|
99
README.md
99
README.md
@ -2,16 +2,12 @@
|
||||
|
||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.0.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.0.4) / Beta: [v1.1.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.1.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- Zero memory allocations at runtime
|
||||
@ -22,11 +18,11 @@ Supported platforms:
|
||||
|
||||
- [x] Mac OS (Intel and Arm)
|
||||
- [x] [iOS](examples/whisper.objc)
|
||||
- [x] [Android](examples/whisper.android)
|
||||
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
|
||||
- [x] Linux
|
||||
- [x] [WebAssembly](examples/whisper.wasm)
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
- [x] [Android](https://github.com/ggerganov/whisper.cpp/issues/30)
|
||||
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
@ -56,6 +52,21 @@ The tensor operators are optimized heavily for Apple silicon CPUs. Depending on
|
||||
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
|
||||
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
- No GPU support
|
||||
- Very basic greedy sampling scheme - always pick up the token with highest probability.
|
||||
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
|
||||
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
|
||||
to run the python code with the following parameters:
|
||||
|
||||
```
|
||||
whisper --best_of None --beam_size None ...
|
||||
```
|
||||
|
||||
In the future, `whisper.cpp` will support more sampling strategies.
|
||||
|
||||
## Quick start
|
||||
|
||||
First, download one of the Whisper models converted in [ggml format](models). For example:
|
||||
@ -71,7 +82,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
|
||||
make
|
||||
|
||||
# transcribe an audio file
|
||||
./main -f samples/jfk.wav
|
||||
./main -f input.wav
|
||||
```
|
||||
|
||||
---
|
||||
@ -89,36 +100,27 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
|
||||
usage: ./main [options] file0.wav file1.wav ...
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
Downloading ggml model base.en ...
|
||||
@ -218,11 +220,6 @@ make large
|
||||
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
- No GPU support (yet)
|
||||
|
||||
## Another example
|
||||
|
||||
Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
|
||||
@ -306,7 +303,6 @@ The [stream](examples/stream) tool samples the audio every half a second and run
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```java
|
||||
make stream
|
||||
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
@ -448,13 +444,12 @@ or manually from here:
|
||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
|
||||
in [models](models).
|
||||
|
||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||
## Bindings
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs)
|
||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm)
|
||||
- [ ] Python:
|
||||
- [ ] Java:
|
||||
|
||||
## Examples
|
||||
|
||||
@ -464,13 +459,11 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| Example | Web | Description |
|
||||
| --- | --- | --- |
|
||||
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
|
||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||
| [bench](examples/bench) | | Benchmark the performance of Whisper on your machine |
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot in your browser |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
|
@ -1,19 +1,3 @@
|
||||
if (EMSCRIPTEN)
|
||||
add_subdirectory(javascript)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/javascript/publish.log
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/whisper.js
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/libwhisper.worker.js
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/package.json
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/javascript
|
||||
COMMAND npm publish
|
||||
COMMAND touch publish.log
|
||||
COMMENT "Publishing npm module v${PROJECT_VERSION}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(publish-npm
|
||||
DEPENDS javascript/publish.log
|
||||
)
|
||||
endif()
|
||||
|
2
bindings/go/.gitignore
vendored
2
bindings/go/.gitignore
vendored
@ -1,2 +0,0 @@
|
||||
build
|
||||
models
|
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 David Thorpe
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,38 +0,0 @@
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
@install -d ${BUILD_DIR}
|
||||
@echo Mkdir ${MODELS_DIR}
|
||||
@install -d ${MODELS_DIR}
|
||||
|
||||
modtidy:
|
||||
@go mod tidy
|
||||
|
||||
clean:
|
||||
@echo Clean
|
||||
@rm -fr $(BUILD_DIR)
|
||||
@go clean
|
@ -1,100 +0,0 @@
|
||||
# Go bindings for Whisper
|
||||
|
||||
This package provides Go bindings for whisper.cpp. They have been tested on:
|
||||
|
||||
* Darwin (OS X) 12.6 on x64_64
|
||||
* Debian Linux on arm64
|
||||
* Fedora Linux on x86_64
|
||||
|
||||
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||
is as follows:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var modelpath string // Path to the model
|
||||
var samples []float32 // Samples to process
|
||||
|
||||
// Load the model
|
||||
model, err := whisper.New(modelpath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := context.Process(samples, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Building & Testing
|
||||
|
||||
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp/bindings/go
|
||||
make test
|
||||
```
|
||||
|
||||
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
||||
|
||||
```bash
|
||||
make examples
|
||||
```
|
||||
|
||||
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-model-download -out models
|
||||
```
|
||||
|
||||
And you can then test a model against samples with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
## Using the bindings
|
||||
|
||||
To use the bindings in your own software,
|
||||
|
||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
||||
|
||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
||||
|
||||
The API Documentation:
|
||||
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
||||
|
||||
Getting help:
|
||||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
@ -1,5 +0,0 @@
|
||||
/*
|
||||
github.com/ggerganov/whisper.cpp/bindings/go
|
||||
provides a speech-to-text service bindings for the Go programming language.
|
||||
*/
|
||||
package whisper
|
@ -1,30 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
// ContextForSignal returns a context object which is cancelled when a signal
|
||||
// is received. It returns nil if no signal parameter is provided
|
||||
func ContextForSignal(signals ...os.Signal) context.Context {
|
||||
if len(signals) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := make(chan os.Signal)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Send message on channel when signal received
|
||||
signal.Notify(ch, signals...)
|
||||
|
||||
// When any signal received, call cancel
|
||||
go func() {
|
||||
<-ch
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Return success
|
||||
return ctx
|
||||
}
|
@ -1,208 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
// The output folder. When not set, use current working directory.
|
||||
flagOut = flag.String("out", "", "Output folder")
|
||||
|
||||
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
||||
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
||||
|
||||
// Quiet parameter - will not print progress if set
|
||||
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MAIN
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
name := filepath.Base(flag.CommandLine.Name())
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Get output path
|
||||
out, err := GetOut()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Create context which quits on SIGINT or SIGQUIT
|
||||
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
||||
|
||||
// Progress filehandle
|
||||
progress := os.Stdout
|
||||
if *flagQuiet {
|
||||
progress, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
defer progress.Close()
|
||||
}
|
||||
|
||||
// Download models - exit on error or interrupt
|
||||
for _, model := range GetModels() {
|
||||
url, err := URLForModel(model)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
continue
|
||||
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
||||
continue
|
||||
} else if err == context.Canceled {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "\nInterrupted")
|
||||
break
|
||||
} else if err == context.DeadlineExceeded {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "Timeout downloading model")
|
||||
continue
|
||||
} else {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// GetOut returns the path to the output directory
|
||||
func GetOut() (string, error) {
|
||||
if *flagOut == "" {
|
||||
return os.Getwd()
|
||||
}
|
||||
if info, err := os.Stat(*flagOut); err != nil {
|
||||
return "", err
|
||||
} else if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory: %s", info.Name())
|
||||
} else {
|
||||
return *flagOut, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns the list of models to download
|
||||
func GetModels() []string {
|
||||
if flag.NArg() == 0 {
|
||||
return modelNames
|
||||
} else {
|
||||
return flag.Args()
|
||||
}
|
||||
}
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
if filepath.Ext(model) != srcExt {
|
||||
model += srcExt
|
||||
}
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = filepath.Join(url.Path, model)
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Download downloads the model from the given URL to the given output directory
|
||||
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
||||
// Create HTTP client
|
||||
client := http.Client{
|
||||
Timeout: *flagTimeout,
|
||||
}
|
||||
|
||||
// Initiate the download
|
||||
req, err := http.NewRequest("GET", model, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
||||
}
|
||||
|
||||
// If output file exists and is the same size as the model, skip
|
||||
path := filepath.Join(out, filepath.Base(model))
|
||||
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
||||
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Create file
|
||||
w, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Report
|
||||
fmt.Fprintln(p, "Downloading", model, "to", out)
|
||||
|
||||
// Progressively download the model
|
||||
data := make([]byte, bufSize)
|
||||
count, pct := int64(0), int64(0)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Cancelled, return error
|
||||
return path, ctx.Err()
|
||||
case <-ticker.C:
|
||||
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
||||
default:
|
||||
// Read body
|
||||
n, err := resp.Body.Read(data)
|
||||
if err != nil {
|
||||
DownloadReport(p, pct, count, resp.ContentLength)
|
||||
return path, err
|
||||
} else if m, err := w.Write(data[:n]); err != nil {
|
||||
return path, err
|
||||
} else {
|
||||
count += int64(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Report periodically reports the download progress when percentage changes
|
||||
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
|
||||
pct_ := count * 100 / total
|
||||
if pct_ > pct {
|
||||
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
|
||||
}
|
||||
return pct_
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
|
||||
RGBSuffix = "m"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Colorize text with RGB values, from 0 to 23
|
||||
func Colorize(text string, v int) string {
|
||||
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
|
||||
// Grayscale colors are in the range 232-255
|
||||
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
|
||||
}
|
@ -1,156 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type Flags struct {
|
||||
*flag.FlagSet
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewFlags(name string, args []string) (*Flags, error) {
|
||||
flags := &Flags{
|
||||
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||
}
|
||||
|
||||
// Register the command line arguments
|
||||
registerFlags(flags)
|
||||
|
||||
// Parse command line
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (flags *Flags) GetModel() string {
|
||||
return flags.Lookup("model").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) GetLanguage() string {
|
||||
return flags.Lookup("language").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTranslate() bool {
|
||||
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOffset() time.Duration {
|
||||
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetDuration() time.Duration {
|
||||
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetThreads() uint {
|
||||
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOut() string {
|
||||
return strings.ToLower(flags.Lookup("out").Value.String())
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsColorize() bool {
|
||||
return flags.Lookup("colorize").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxLen() uint {
|
||||
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxTokens() uint {
|
||||
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetWordThreshold() float32 {
|
||||
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
|
||||
}
|
||||
|
||||
func (flags *Flags) SetParams(context whisper.Context) error {
|
||||
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
|
||||
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if flags.IsTranslate() && context.IsMultilingual() {
|
||||
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
|
||||
context.SetTranslate(true)
|
||||
}
|
||||
if offset := flags.GetOffset(); offset != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
|
||||
context.SetOffset(offset)
|
||||
}
|
||||
if duration := flags.GetDuration(); duration != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
||||
context.SetDuration(duration)
|
||||
}
|
||||
if flags.IsSpeedup() {
|
||||
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if threads := flags.GetThreads(); threads != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
||||
context.SetThreads(threads)
|
||||
}
|
||||
if max_len := flags.GetMaxLen(); max_len != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
|
||||
context.SetMaxSegmentLength(max_len)
|
||||
}
|
||||
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
|
||||
context.SetMaxTokensPerSegment(max_tokens)
|
||||
}
|
||||
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
|
||||
context.SetTokenThreshold(word_threshold)
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func registerFlags(flag *Flags) {
|
||||
flag.String("model", "", "Path to the model file")
|
||||
flag.String("language", "", "Spoken language")
|
||||
flag.Bool("translate", false, "Translate from source language to english")
|
||||
flag.Duration("offset", 0, "Time offset")
|
||||
flag.Duration("duration", 0, "Duration of audio to process")
|
||||
flag.Uint("threads", 0, "Number of threads to use")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
||||
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
||||
flag.Float64("word-thold", 0, "Maximum segment score")
|
||||
flag.Bool("tokens", false, "Display tokens")
|
||||
flag.Bool("colorize", false, "Colorize tokens")
|
||||
flag.String("out", "", "Output format (srt, none or leave as empty string)")
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
||||
if err == flag.ErrHelp {
|
||||
os.Exit(0)
|
||||
} else if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
} else if flags.GetModel() == "" {
|
||||
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
||||
os.Exit(1)
|
||||
} else if flags.NArg() == 0 {
|
||||
fmt.Fprintln(os.Stderr, "No input files specified")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(flags.GetModel())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process files
|
||||
for _, filename := range flags.Args() {
|
||||
if err := Process(model, filename, flags); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
@ -1,127 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// Package imports
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
var data []float32
|
||||
|
||||
// Create processing context
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
if err := flags.SetParams(context); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open the file
|
||||
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||
return err
|
||||
} else if dec.SampleRate != whisper.SampleRate {
|
||||
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||
} else if dec.NumChans != 1 {
|
||||
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||
} else {
|
||||
data = buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// Segment callback when -tokens is specified
|
||||
var cb whisper.SegmentCallback
|
||||
if flags.IsTokens() {
|
||||
cb = func(segment whisper.Segment) {
|
||||
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
for _, token := range segment.Tokens {
|
||||
if flags.IsColorize() && context.IsText(token) {
|
||||
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
|
||||
} else {
|
||||
fmt.Fprint(flags.Output(), token.Text, " ")
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
switch {
|
||||
case flags.GetOut() == "srt":
|
||||
return OutputSRT(os.Stdout, context)
|
||||
case flags.GetOut() == "none":
|
||||
return nil
|
||||
default:
|
||||
return Output(os.Stdout, context, flags.IsColorize())
|
||||
}
|
||||
}
|
||||
|
||||
// Output text as SRT file
|
||||
func OutputSRT(w io.Writer, context whisper.Context) error {
|
||||
n := 1
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(w, n)
|
||||
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
|
||||
fmt.Fprintln(w, segment.Text)
|
||||
fmt.Fprintln(w, "")
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Output text to terminal
|
||||
func Output(w io.Writer, context whisper.Context, colorize bool) error {
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
if colorize {
|
||||
for _, token := range segment.Tokens {
|
||||
if !context.IsText(token) {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
|
||||
}
|
||||
fmt.Fprint(w, "\n")
|
||||
} else {
|
||||
fmt.Fprintln(w, " ", segment.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return srtTimestamp
|
||||
func srtTimestamp(t time.Duration) string {
|
||||
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
module github.com/ggerganov/whisper.cpp/bindings/go
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
@ -1,23 +0,0 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
@ -1,156 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#include <whisper.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (p *Params) SetTranslate(v bool) {
|
||||
p.translate = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetNoContext(v bool) {
|
||||
p.no_context = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSingleSegment(v bool) {
|
||||
p.single_segment = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintSpecial(v bool) {
|
||||
p.print_special = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintProgress(v bool) {
|
||||
p.print_progress = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintRealtime(v bool) {
|
||||
p.print_realtime = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintTimestamps(v bool) {
|
||||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSpeedup(v bool) {
|
||||
p.speed_up = toBool(v)
|
||||
}
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
str := C.whisper_lang_str(C.int(lang))
|
||||
if str == nil {
|
||||
return ErrInvalidLanguage
|
||||
} else {
|
||||
p.language = str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get language id
|
||||
func (p *Params) Language() int {
|
||||
if p.language == nil {
|
||||
return -1
|
||||
}
|
||||
return int(C.whisper_lang_id(p.language))
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (p *Params) SetThreads(threads int) {
|
||||
p.n_threads = C.int(threads)
|
||||
}
|
||||
|
||||
// Set start offset in ms
|
||||
func (p *Params) SetOffset(offset_ms int) {
|
||||
p.offset_ms = C.int(offset_ms)
|
||||
}
|
||||
|
||||
// Set audio duration to process in ms
|
||||
func (p *Params) SetDuration(duration_ms int) {
|
||||
p.duration_ms = C.int(duration_ms)
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (p *Params) SetTokenThreshold(t float32) {
|
||||
p.thold_pt = C.float(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (p *Params) SetTokenSumThreshold(t float32) {
|
||||
p.thold_ptsum = C.float(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (p *Params) SetMaxSegmentLength(n int) {
|
||||
p.max_len = C.int(n)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (p *Params) SetMaxTokensPerSegment(n int) {
|
||||
p.max_tokens = C.int(n)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toBool(v bool) C.bool {
|
||||
if v {
|
||||
return C.bool(true)
|
||||
}
|
||||
return C.bool(false)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (p *Params) String() string {
|
||||
str := "<whisper.params"
|
||||
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
||||
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
||||
if p.language != nil {
|
||||
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
||||
}
|
||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
if p.no_context {
|
||||
str += " no_context"
|
||||
}
|
||||
if p.single_segment {
|
||||
str += " single_segment"
|
||||
}
|
||||
if p.print_special {
|
||||
str += " print_special"
|
||||
}
|
||||
if p.print_progress {
|
||||
str += " print_progress"
|
||||
}
|
||||
if p.print_realtime {
|
||||
str += " print_realtime"
|
||||
}
|
||||
if p.print_timestamps {
|
||||
str += " print_timestamps"
|
||||
}
|
||||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.speed_up {
|
||||
str += " speed_up"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
@ -1,28 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrModelNotMultilingual = errors.New("model is not multilingual")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
// SampleRate is the sample rate of the audio data.
|
||||
const SampleRate = whisper.SampleRate
|
||||
|
||||
// SampleBits is the number of bytes per sample.
|
||||
const SampleBits = whisper.SampleBits
|
@ -1,251 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type context struct {
|
||||
n int
|
||||
model *model
|
||||
params whisper.Params
|
||||
}
|
||||
|
||||
// Make sure context adheres to the interface
|
||||
var _ Context = (*context)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func newContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
|
||||
// Return success
|
||||
return context, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Set the language to use for speech recognition.
|
||||
func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *context) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
// Set translate flag
|
||||
func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (context *context) SetThreads(v uint) {
|
||||
context.params.SetThreads(int(v))
|
||||
}
|
||||
|
||||
// Set time offset
|
||||
func (context *context) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set duration of audio to process
|
||||
func (context *context) SetDuration(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (context *context) SetTokenThreshold(t float32) {
|
||||
context.params.SetTokenThreshold(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (context *context) SetTokenSumThreshold(t float32) {
|
||||
context.params.SetTokenSumThreshold(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if cb != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
// We don't do parallel processing at the moment
|
||||
processors := 0
|
||||
if processors > 1 {
|
||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the next segment of tokens
|
||||
func (context *context) NextSegment() (Segment, error) {
|
||||
if context.model.ctx == nil {
|
||||
return Segment{}, ErrInternalAppError
|
||||
}
|
||||
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
// Populate result
|
||||
result := toSegment(context.model.ctx, context.n)
|
||||
|
||||
// Increment the cursor
|
||||
context.n++
|
||||
|
||||
// Return success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Test for text tokens
|
||||
func (context *context) IsText(t Token) bool {
|
||||
switch {
|
||||
case context.IsBEG(t):
|
||||
return false
|
||||
case context.IsSOT(t):
|
||||
return false
|
||||
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
||||
return false
|
||||
case context.IsPREV(t):
|
||||
return false
|
||||
case context.IsSOLM(t):
|
||||
return false
|
||||
case context.IsNOT(t):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Test for "begin" token
|
||||
func (context *context) IsBEG(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
||||
}
|
||||
|
||||
// Test for "start of transcription" token
|
||||
func (context *context) IsSOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
||||
}
|
||||
|
||||
// Test for "end of transcription" token
|
||||
func (context *context) IsEOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
||||
}
|
||||
|
||||
// Test for "start of prev" token
|
||||
func (context *context) IsPREV(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
||||
}
|
||||
|
||||
// Test for "start of lm" token
|
||||
func (context *context) IsSOLM(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
||||
}
|
||||
|
||||
// Test for "No timestamps" token
|
||||
func (context *context) IsNOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
||||
}
|
||||
|
||||
// Test for token associated with a specific language
|
||||
func (context *context) IsLANG(t Token, lang string) bool {
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokens(ctx, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "../../models/ggml-tiny.bin"
|
||||
SamplePath = "../../samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
assert.NoError(model.Close())
|
||||
|
||||
t.Log("languages=", model.Languages())
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
// Get context for decoding
|
||||
ctx, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
assert.NotNil(ctx)
|
||||
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
/*
|
||||
This is the higher-level speech-to-text whisper.cpp API for go
|
||||
*/
|
||||
package whisper
|
@ -1,85 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
// SegmentCallback is the callback function for processing segments in real
|
||||
// time. It is called during the Process function
|
||||
type SegmentCallback func(Segment)
|
||||
|
||||
// Model is the interface to a whisper model. Create a new model with the
|
||||
// function whisper.New(string)
|
||||
type Model interface {
|
||||
io.Closer
|
||||
|
||||
// Return a new speech-to-text context.
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return true if the model is multilingual.
|
||||
IsMultilingual() bool
|
||||
|
||||
// Return all languages supported.
|
||||
Languages() []string
|
||||
}
|
||||
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
// callback function during processing.
|
||||
Process([]float32, SegmentCallback) error
|
||||
|
||||
// After process is called, return segments until the end of the stream
|
||||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
|
||||
IsBEG(Token) bool // Test for "begin" token
|
||||
IsSOT(Token) bool // Test for "start of transcription" token
|
||||
IsEOT(Token) bool // Test for "end of transcription" token
|
||||
IsPREV(Token) bool // Test for "start of prev" token
|
||||
IsSOLM(Token) bool // Test for "start of lm" token
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
type Segment struct {
|
||||
// Segment Number
|
||||
Num int
|
||||
|
||||
// Time beginning and end timestamps for the segment.
|
||||
Start, End time.Duration
|
||||
|
||||
// The text of the segment.
|
||||
Text string
|
||||
|
||||
// The tokens of the segment.
|
||||
Tokens []Token
|
||||
}
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
}
|
@ -1,100 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type model struct {
|
||||
path string
|
||||
ctx *whisper.Context
|
||||
}
|
||||
|
||||
// Make sure model adheres to the interface
|
||||
var _ Model = (*model)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
if model.ctx != nil {
|
||||
model.ctx.Whisper_free()
|
||||
}
|
||||
|
||||
// Release resources
|
||||
model.ctx = nil
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (model *model) String() string {
|
||||
str := "<whisper.model"
|
||||
if model.ctx != nil {
|
||||
str += fmt.Sprintf(" model=%q", model.path)
|
||||
}
|
||||
return str + ">"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return true if model is multilingual (language and translation options are supported)
|
||||
func (model *model) IsMultilingual() bool {
|
||||
return model.ctx.Whisper_is_multilingual() != 0
|
||||
}
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (model *model) NewContext() (Context, error) {
|
||||
if model.ctx == nil {
|
||||
return nil, ErrInternalAppError
|
||||
}
|
||||
|
||||
// Create new context
|
||||
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
|
||||
// Return new context
|
||||
return newContext(model, params)
|
||||
}
|
Binary file not shown.
@ -1,409 +0,0 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
extern bool callEncoderBegin(void* user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callNewSegment(user_data, n_new);
|
||||
}
|
||||
}
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
return callEncoderBegin(user_data);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get default parameters and set callbacks
|
||||
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||
params.new_segment_callback = whisper_new_segment_cb;
|
||||
params.new_segment_callback_user_data = (void*)(ctx);
|
||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type (
|
||||
Context C.struct_whisper_context
|
||||
Token C.whisper_token
|
||||
TokenData C.struct_whisper_token_data
|
||||
SamplingStrategy C.enum_whisper_sampling_strategy
|
||||
Params C.struct_whisper_full_params
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GLOBALS
|
||||
|
||||
const (
|
||||
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
||||
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
||||
)
|
||||
|
||||
const (
|
||||
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
||||
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
||||
NumFFT = C.WHISPER_N_FFT
|
||||
NumMEL = C.WHISPER_N_MEL
|
||||
HopLength = C.WHISPER_HOP_LENGTH
|
||||
ChunkSize = C.WHISPER_CHUNK_SIZE
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
||||
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
||||
ErrConversionFailed = errors.New("whisper_convert failed")
|
||||
ErrInvalidLanguage = errors.New("invalid language")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Allocates all memory needed for the model and loads the model from the given file.
|
||||
// Returns NULL on failure.
|
||||
func Whisper_init(path string) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the provided whisper context.
|
||||
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
||||
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
||||
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
// Make sure to call whisper_encode() first.
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
// n_past is the number of tokens to use from previous decoder calls.
|
||||
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
||||
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success
|
||||
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
cText := C.CString(text)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
||||
return int(n), nil
|
||||
} else {
|
||||
return 0, ErrTokenizerFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
//
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
func Whisper_lang_max_id() int {
|
||||
return int(C.whisper_lang_max_id())
|
||||
}
|
||||
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
||||
// returns empty string if not found
|
||||
func Whisper_lang_str(id int) string {
|
||||
return C.GoString(C.whisper_lang_str(C.int(id)))
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
||||
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||
return nil, ErrAutoDetectFailed
|
||||
} else {
|
||||
return probs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len() int {
|
||||
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_vocab() int {
|
||||
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// The probabilities for the next token
|
||||
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
||||
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
||||
//}
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
||||
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_eot() Token {
|
||||
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_sot() Token {
|
||||
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_prev() Token {
|
||||
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_solm() Token {
|
||||
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_not() Token {
|
||||
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_beg() Token {
|
||||
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
||||
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_translate() Token {
|
||||
return Token(C.whisper_token_translate())
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_transcribe() Token {
|
||||
return Token(C.whisper_token_transcribe())
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_print_timings() {
|
||||
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_reset_timings() {
|
||||
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Print system information
|
||||
func Whisper_print_system_info() string {
|
||||
return C.GoString(C.whisper_print_system_info())
|
||||
}
|
||||
|
||||
// Return default parameters for a strategy
|
||||
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
||||
// Get default parameters
|
||||
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
||||
}
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
|
||||
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Number of generated text segments.
|
||||
// A segment can be a few words, a sentence, or even a paragraph.
|
||||
func (ctx *Context) Whisper_full_n_segments() int {
|
||||
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the text of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get number of tokens in the specified segment.
|
||||
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the token text of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the token of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CALLBACKS
|
||||
|
||||
var (
|
||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||
)
|
||||
|
||||
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
if fn == nil {
|
||||
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
//export callNewSegment
|
||||
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
if fn, ok := cbNewSegment[user_data]; ok {
|
||||
fn(int(new))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
if fn() {
|
||||
return C.bool(true)
|
||||
} else {
|
||||
return C.bool(false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
@ -1,113 +0,0 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
wav "github.com/go-audio/wav"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "models/ggml-small.en.bin"
|
||||
SamplePath = "samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
ctx.Whisper_free()
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Run whisper
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
err = ctx.Whisper_full(params, data, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
assert.GreaterOrEqual(num_segments, 1)
|
||||
for i := 0; i < num_segments; i++ {
|
||||
str := ctx.Whisper_full_get_segment_text(i)
|
||||
assert.NotEmpty(str)
|
||||
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
||||
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
||||
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_002(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
assert.NotEmpty(str)
|
||||
t.Log(str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_003(t *testing.T) {
|
||||
threads := runtime.NumCPU()
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Make the model
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
// Get MEL
|
||||
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
||||
|
||||
// Get Languages
|
||||
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
||||
assert.NoError(err)
|
||||
for i, p := range languages {
|
||||
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||
}
|
||||
}
|
Submodule bindings/ios updated: f6334b026f...4bda8e9d80
@ -20,22 +20,15 @@ if (WHISPER_WASM_SINGLE_FILE)
|
||||
${CMAKE_BINARY_DIR}/bin/libwhisper.js
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/whisper.js
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libwhisper.worker.js
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/libwhisper.worker.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s MODULARIZE=1 \
|
||||
-s EXPORT_NAME=\"'whisper_factory'\" \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s ALLOW_MEMORY_GROWTH=1 \
|
||||
-s INITIAL_MEMORY=1610612736 \
|
||||
-s TOTAL_MEMORY=1610612736 \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
@ -1,78 +0,0 @@
|
||||
# whisper.cpp
|
||||
|
||||
Node.js package for Whisper speech recognition
|
||||
|
||||
Package: https://www.npmjs.com/package/whisper.cpp
|
||||
|
||||
## Details
|
||||
|
||||
The performance is comparable to when running `whisper.cpp` in the browser via WASM.
|
||||
|
||||
The API is currently very rudimentary: [bindings/javascript/emscripten.cpp](/bindings/javascript/emscripten.cpp)
|
||||
|
||||
For sample usage check [tests/test-whisper.js](/tests/test-whisper.js)
|
||||
|
||||
## Package building + test
|
||||
|
||||
```bash
|
||||
# load emscripten
|
||||
source /path/to/emsdk/emsdk_env.sh
|
||||
|
||||
# clone repo
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
|
||||
# grab base.en model
|
||||
./models/download-ggml-model.sh base.en
|
||||
|
||||
# prepare PCM sample for testing
|
||||
ffmpeg -i samples/jfk.wav -f f32le -acodec pcm_f32le samples/jfk.pcmf32
|
||||
|
||||
# build
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake .. && make -j
|
||||
|
||||
# run test
|
||||
node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
|
||||
|
||||
# publish npm package
|
||||
make publish-npm
|
||||
```
|
||||
|
||||
## Sample run
|
||||
|
||||
```java
|
||||
$ node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
|
||||
|
||||
whisper_model_load: loading model from 'whisper.bin'
|
||||
whisper_model_load: n_vocab = 51864
|
||||
whisper_model_load: n_audio_ctx = 1500
|
||||
whisper_model_load: n_audio_state = 512
|
||||
whisper_model_load: n_audio_head = 8
|
||||
whisper_model_load: n_audio_layer = 6
|
||||
whisper_model_load: n_text_ctx = 448
|
||||
whisper_model_load: n_text_state = 512
|
||||
whisper_model_load: n_text_head = 8
|
||||
whisper_model_load: n_text_layer = 6
|
||||
whisper_model_load: n_mels = 80
|
||||
whisper_model_load: f16 = 1
|
||||
whisper_model_load: type = 2
|
||||
whisper_model_load: adding 1607 extra tokens
|
||||
whisper_model_load: mem_required = 506.00 MB
|
||||
whisper_model_load: ggml ctx size = 140.60 MB
|
||||
whisper_model_load: memory size = 22.83 MB
|
||||
whisper_model_load: model size = 140.54 MB
|
||||
|
||||
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 1 | BLAS = 0 |
|
||||
|
||||
operator(): processing 176000 samples, 11.0 sec, 8 threads, 1 processors, lang = en, task = transcribe ...
|
||||
|
||||
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
whisper_print_timings: load time = 162.37 ms
|
||||
whisper_print_timings: mel time = 183.70 ms
|
||||
whisper_print_timings: sample time = 4.27 ms
|
||||
whisper_print_timings: encode time = 8582.63 ms / 1430.44 ms per layer
|
||||
whisper_print_timings: decode time = 436.16 ms / 72.69 ms per layer
|
||||
whisper_print_timings: total time = 9370.90 ms
|
||||
```
|
@ -1,48 +1,63 @@
|
||||
//
|
||||
// This is the Javascript API of whisper.cpp
|
||||
//
|
||||
// Very crude at the moment.
|
||||
// Feel free to contribute and make this better!
|
||||
//
|
||||
// See the tests/test-whisper.js for sample usage
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
|
||||
struct whisper_context * g_context;
|
||||
std::thread g_worker;
|
||||
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
EMSCRIPTEN_BINDINGS(whisper) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
if (g_context == nullptr) {
|
||||
g_context = whisper_init_from_file(path_model.c_str());
|
||||
if (g_context != nullptr) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([]() {
|
||||
if (g_context) {
|
||||
whisper_free(g_context);
|
||||
g_context = nullptr;
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
|
||||
--index;
|
||||
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}));
|
||||
|
||||
emscripten::function("full_default", emscripten::optional_override([](const emscripten::val & audio, const std::string & lang, bool translate) {
|
||||
if (g_context == nullptr) {
|
||||
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, bool translate) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
|
||||
--index;
|
||||
|
||||
if (index >= g_contexts.size()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (g_contexts[index] == nullptr) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
struct whisper_full_params params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
params.print_realtime = true;
|
||||
@ -50,7 +65,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = translate;
|
||||
params.language = whisper_is_multilingual(g_context) ? lang.c_str() : "en";
|
||||
params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
|
||||
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
|
||||
params.offset_ms = 0;
|
||||
|
||||
@ -67,11 +82,9 @@ EMSCRIPTEN_BINDINGS(whisper) {
|
||||
|
||||
// print system information
|
||||
{
|
||||
printf("\n");
|
||||
printf("system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
|
||||
printf("\n");
|
||||
printf("%s: processing %d samples, %.1f sec, %d threads, %d processors, lang = %s, task = %s ...\n",
|
||||
__func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, 1,
|
||||
@ -81,11 +94,13 @@ EMSCRIPTEN_BINDINGS(whisper) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// run whisper
|
||||
// run the worker
|
||||
{
|
||||
whisper_reset_timings(g_context);
|
||||
whisper_full(g_context, params, pcmf32.data(), pcmf32.size());
|
||||
whisper_print_timings(g_context);
|
||||
g_worker = std::thread([index, params, pcmf32 = std::move(pcmf32)]() {
|
||||
whisper_reset_timings(g_contexts[index]);
|
||||
whisper_full(g_contexts[index], params, pcmf32.data(), pcmf32.size());
|
||||
whisper_print_timings(g_contexts[index]);
|
||||
});
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
@ -1 +0,0 @@
|
||||
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f)},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}var initializedJS=false;var pendingNotifiedProxyingQueues=[];function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};self.onunhandledrejection=e=>{throw e.reason??e};self.onmessage=e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=function(){postMessage({cmd:"callHandler",handler:handler,args:[...arguments]})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();pendingNotifiedProxyingQueues.forEach(queue=>{Module["executeNotifiedProxyingQueue"](queue)});pendingNotifiedProxyingQueues=[];initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processProxyingQueue"){if(initializedJS){Module["executeNotifiedProxyingQueue"](e.data.queue)}else{pendingNotifiedProxyingQueues.push(e.data.queue)}}else if(e.data.cmd){err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}};
|
@ -1,26 +0,0 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "@PROJECT_VERSION@",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
"test": "echo \"todo: add tests\" && exit 0"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"keywords": [
|
||||
"openai",
|
||||
"whisper",
|
||||
"speech-to-text",
|
||||
"speech-recognition",
|
||||
"transformer"
|
||||
],
|
||||
"author": "Georgi Gerganov",
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ggerganov/whisper.cpp/issues"
|
||||
},
|
||||
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.1.0",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
"test": "echo \"todo: add tests\" && exit 0"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"keywords": [
|
||||
"openai",
|
||||
"whisper",
|
||||
"speech-to-text",
|
||||
"speech-recognition",
|
||||
"transformer"
|
||||
],
|
||||
"author": "Georgi Gerganov",
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/ggerganov/whisper.cpp/issues"
|
||||
},
|
||||
"homepage": "https://github.com/ggerganov/whisper.cpp#readme"
|
||||
}
|
File diff suppressed because one or more lines are too long
@ -1,17 +0,0 @@
|
||||
# Set the default compile features and properties for a target.
|
||||
|
||||
if (NOT TARGET)
|
||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET}
|
||||
PRIVATE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
@ -23,11 +23,9 @@ if (EMSCRIPTEN)
|
||||
add_subdirectory(stream.wasm)
|
||||
add_subdirectory(command.wasm)
|
||||
add_subdirectory(talk.wasm)
|
||||
add_subdirectory(bench.wasm)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(talk)
|
||||
endif()
|
||||
|
@ -1,49 +0,0 @@
|
||||
#
|
||||
# libbench
|
||||
#
|
||||
|
||||
set(TARGET libbench)
|
||||
|
||||
add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
||||
if (WHISPER_WASM_SINGLE_FILE)
|
||||
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
|
||||
message(STATUS "Embedding WASM inside bench.js")
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_BINARY_DIR}/bin/libbench.js
|
||||
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/bench.wasm/bench.js
|
||||
)
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
#
|
||||
# bench.wasm
|
||||
#
|
||||
|
||||
set(TARGET bench.wasm)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)
|
@ -1,22 +0,0 @@
|
||||
# bench.wasm
|
||||
|
||||
Benchmark the performance of whisper.cpp in the browser using WebAssembly
|
||||
|
||||
Link: https://whisper.ggerganov.com/bench/
|
||||
|
||||
Terminal version: [examples/bench](/examples/bench)
|
||||
|
||||
## Build instructions
|
||||
|
||||
```bash
|
||||
# build using Emscripten (v3.1.2)
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
cd whisper.cpp
|
||||
mkdir build-em && cd build-em
|
||||
emcmake cmake ..
|
||||
make -j
|
||||
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/bench.wasm/* /path/to/html/
|
||||
cp bin/libbench.worker.js /path/to/html/
|
||||
```
|
@ -1,85 +0,0 @@
|
||||
#include "whisper.h"
|
||||
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
// TODO: get rid of this vector of contexts - bad idea in the first place
|
||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
|
||||
std::thread g_worker;
|
||||
|
||||
void bench_main(size_t index) {
|
||||
const int n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// whisper context
|
||||
auto & ctx = g_contexts[index];
|
||||
|
||||
fprintf(stderr, "%s: running benchmark with %d threads - please wait...\n", __func__, n_threads);
|
||||
|
||||
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
|
||||
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "If you wish, you can submit these results here:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "Please include the following information:\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, " - CPU model\n");
|
||||
fprintf(stderr, " - Operating system\n");
|
||||
fprintf(stderr, " - Browser\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
EMSCRIPTEN_BINDINGS(bench) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
if (g_worker.joinable()) {
|
||||
g_worker.join();
|
||||
}
|
||||
g_worker = std::thread([i]() {
|
||||
bench_main(i);
|
||||
});
|
||||
|
||||
return i + 1;
|
||||
} else {
|
||||
return (size_t) 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (size_t) 0;
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t index) {
|
||||
if (index < g_contexts.size()) {
|
||||
whisper_free(g_contexts[index]);
|
||||
g_contexts[index] = nullptr;
|
||||
}
|
||||
}));
|
||||
}
|
@ -1,227 +0,0 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>bench : Benchmark whisper.cpp performance in the browser</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 0 auto;
|
||||
margin-top: 10px;
|
||||
border-left: 0px;
|
||||
border-right: 0px;
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
display: block;
|
||||
background-color: black;
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
font-family: 'Lucida Console', Monaco, monospace;
|
||||
outline: none;
|
||||
white-space: pre;
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="main-container">
|
||||
<b>bench : Benchmark whisper.cpp performance in the browser</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
Select the model you would like to use and click the "Bench" button.<br>
|
||||
The results will be displayed in the textarea below.
|
||||
|
||||
<br><br>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
|
||||
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
|
||||
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div id="input">
|
||||
<button id="bench" onclick="onBench()" disabled>Bench</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
Debug output:
|
||||
<textarea id="output" rows="20"></textarea>
|
||||
|
||||
<br>
|
||||
|
||||
<b>Troubleshooting</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
The page does some heavy computations, so make sure:
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
<div class="cell-version">
|
||||
<span>
|
||||
|
|
||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/bench.wasm">Source Code</a> |
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript" src="helpers.js"></script>
|
||||
<script type='text/javascript'>
|
||||
// the bench instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
printErr: printTextarea,
|
||||
setStatus: function(text) {
|
||||
printTextarea('js: ' + text);
|
||||
},
|
||||
monitorRunDependencies: function(left) {
|
||||
},
|
||||
preRun: function() {
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Initialized successfully!');
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
|
||||
let dbVersion = 1
|
||||
let dbName = 'whisper.ggerganov.com';
|
||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
||||
|
||||
function storeFS(fname, buf) {
|
||||
// write to WASM file using FS_createDataFile
|
||||
// if the file exists, delete it
|
||||
try {
|
||||
Module.FS_unlink(fname);
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
Module.FS_createDataFile("/", fname, buf, true, true);
|
||||
|
||||
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
|
||||
|
||||
model_whisper = fname;
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
if (model_whisper != null) {
|
||||
document.getElementById('bench').disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function loadFile(event, fname) {
|
||||
var file = event.target.files[0] || null;
|
||||
if (file == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
printTextarea("loadFile: loading model: " + file.name + ", size: " + file.size + " bytes");
|
||||
printTextarea('loadFile: please wait ...');
|
||||
|
||||
var reader = new FileReader();
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
storeFS(fname, buf);
|
||||
}
|
||||
reader.readAsArrayBuffer(file);
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('whisper-file' ).style.display = 'none';
|
||||
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
|
||||
}
|
||||
|
||||
function loadWhisper(model) {
|
||||
let urls = {
|
||||
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
};
|
||||
|
||||
let sizes = {
|
||||
'tiny.en': 75,
|
||||
'base.en': 142,
|
||||
};
|
||||
|
||||
let url = urls[model];
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = sizes[model];
|
||||
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
document.getElementById('fetch-whisper-base-en').style.display = 'none';
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
el.innerHTML = Math.round(100*p) + '%';
|
||||
};
|
||||
|
||||
cbCancel = function() {
|
||||
var el;
|
||||
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
|
||||
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
}
|
||||
|
||||
//
|
||||
// main
|
||||
//
|
||||
|
||||
function onBench() {
|
||||
if (instance) {
|
||||
Module.free(instance);
|
||||
}
|
||||
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
|
||||
document.getElementById('bench').disabled = true;
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="bench.js"></script>
|
||||
</body>
|
||||
</html>
|
@ -1,6 +1,3 @@
|
||||
set(TARGET bench)
|
||||
add_executable(${TARGET} bench.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -1,8 +1,6 @@
|
||||
# bench
|
||||
|
||||
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of
|
||||
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
|
||||
of the performance of the model for various setups.
|
||||
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of the transformer on some random audio data and records the execution time. This way we can have an objective comparison of the performance of the model for various setups.
|
||||
|
||||
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
|
||||
|
||||
|
@ -7,7 +7,6 @@
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
|
||||
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
};
|
||||
@ -24,7 +23,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -35,7 +33,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
@ -43,17 +41,19 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
||||
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
|
||||
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
||||
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int whisper_bench_encoder(const whisper_params & params) {
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
@ -92,22 +92,3 @@ int whisper_bench_encoder(const whisper_params & params) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ret = -1;
|
||||
|
||||
switch (params.what) {
|
||||
case 0: ret = whisper_bench_encoder(params); break;
|
||||
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
|
||||
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
|
||||
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -8,30 +8,13 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
./command -m ./models/ggml-small.en.bin -t 8
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 4 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Guided mode
|
||||
|
||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
||||
|
||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
||||
|
||||
```bash
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
||||
|
||||
|
||||
## Building
|
||||
|
||||
The `command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
@ -20,13 +19,12 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t prompt_ms = 5000;
|
||||
int32_t command_ms = 8000;
|
||||
int32_t command_ms = 4000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
@ -36,15 +34,14 @@ struct whisper_params {
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool no_context = true;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string fname_out = "";
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -72,8 +69,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -84,29 +79,27 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -391,7 +384,7 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
@ -492,350 +485,54 @@ float similarity(const std::string & s0, const std::string & s1) {
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
std::vector<std::string> allowed_commands;
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs.is_open()) {
|
||||
return allowed_commands;
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
line = trim(line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::transform(line.begin(), line.end(),line.begin(), ::tolower);
|
||||
allowed_commands.push_back(std::move(line));
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return allowed_commands;
|
||||
}
|
||||
// whisper init
|
||||
|
||||
std::vector<std::string> get_words(const std::string &txt) {
|
||||
std::vector<std::string> words;
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
|
||||
std::istringstream iss(txt);
|
||||
std::string word;
|
||||
while (iss >> word) {
|
||||
words.push_back(word);
|
||||
}
|
||||
|
||||
return words;
|
||||
}
|
||||
|
||||
// returns true if no exit event was received
|
||||
bool process_sdl_events() {
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
return false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
int max_len = 0;
|
||||
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
std::vector<whisper_token> k_tokens;
|
||||
// print some info about the processing
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
bool is_running = true;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
// estimate command probability
|
||||
// NOTE: not optimal
|
||||
{
|
||||
const auto * logits = whisper_get_logits(ctx);
|
||||
|
||||
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
|
||||
|
||||
// compute probs from logits via softmax
|
||||
{
|
||||
float max = -1e9;
|
||||
for (int i = 0; i < (int) probs.size(); ++i) {
|
||||
max = std::max(max, logits[i]);
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < (int) probs.size(); ++i) {
|
||||
probs[i] = expf(logits[i] - max);
|
||||
sum += probs[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) probs.size(); ++i) {
|
||||
probs[i] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int token : allowed_tokens[cmd.second]) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// best command
|
||||
{
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const float prob = probs_id[0].first;
|
||||
const int index = probs_id[0].second;
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// always-prompt mode
|
||||
// transcribe the voice into text after valid prompt
|
||||
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
const std::string k_prompt = params.prompt;
|
||||
|
||||
const int k_prompt_length = get_words(k_prompt).size();
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: always-prompt mode\n", __func__);
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
std::string prompt;
|
||||
std::string command;
|
||||
|
||||
for (int i = 0; i < (int) words.size(); ++i) {
|
||||
if (i < k_prompt_length) {
|
||||
prompt += words[i] + " ";
|
||||
} else {
|
||||
command += words[i] + " ";
|
||||
}
|
||||
}
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//debug
|
||||
//fprintf(stdout, "command size: %i\n", command_length);
|
||||
|
||||
if ((sim > 0.7f) && (command.size() > 0)) {
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
audio.clear();
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
@ -848,13 +545,26 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
is_running = process_sdl_events();
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
@ -867,16 +577,15 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
@ -899,7 +608,6 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
@ -938,74 +646,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
// wait for 1 second to avoid any buffered noise
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
audio.clear();
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return ret_val;
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,9 +0,0 @@
|
||||
enable
|
||||
disable
|
||||
cat
|
||||
dog
|
||||
apple
|
||||
red
|
||||
blue
|
||||
green
|
||||
lightblue
|
@ -1,33 +1,19 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
set -eo pipefail
|
||||
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
|
||||
# Idea by @semiformal-net
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/185
|
||||
#
|
||||
|
||||
set -eo pipefail
|
||||
# TODO:
|
||||
# - Currently, there is a gap between sequential chunks, so some of the words are dropped. Need to figure out a
|
||||
# way to produce a continuous stream of audio chunks.
|
||||
#
|
||||
|
||||
url="http://a.files.bbci.co.uk/media/live/manifesto/audio/simulcast/hls/nonuk/sbr_low/ak/bbc_world_service.m3u8"
|
||||
fmt=aac # the audio format extension of the stream (TODO: auto detect)
|
||||
step_s=30
|
||||
model="base.en"
|
||||
|
||||
check_requirements()
|
||||
{
|
||||
if ! command -v ./main &>/dev/null; then
|
||||
echo "whisper.cpp main executable is required (make)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v ffmpeg &>/dev/null; then
|
||||
echo "ffmpeg is required (https://ffmpeg.org)"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_requirements
|
||||
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 stream_url [step_s] [model]"
|
||||
echo ""
|
||||
|
@ -1,6 +1,3 @@
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -59,29 +59,22 @@ struct whisper_params {
|
||||
int32_t duration_ms = 0;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 5;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
float word_thold = 0.01f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
};
|
||||
@ -101,34 +94,27 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -139,40 +125,33 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -188,81 +167,90 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
}
|
||||
fflush(stdout);
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
|
||||
}
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,7 +266,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
fout << text << "\n";
|
||||
fout << text;
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -331,35 +319,10 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
if (text[0] == ' ') {
|
||||
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
|
||||
}
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
|
||||
std::ofstream fout(fname);
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
@ -408,6 +371,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
txt_ul = "\\ \\ ";
|
||||
|
||||
{
|
||||
int ncnt = 0;
|
||||
for (int k = 0; k < n; ++k) {
|
||||
const auto & token2 = tokens[k];
|
||||
|
||||
@ -431,6 +395,8 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
txt_ul += "\\ ";
|
||||
}
|
||||
}
|
||||
|
||||
ncnt += txt.size();
|
||||
}
|
||||
|
||||
::replace_all(txt_bg, "'", "\u2019");
|
||||
@ -481,7 +447,7 @@ int main(int argc, char ** argv) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@ -489,29 +455,13 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (!params.prompt.empty()) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||
fprintf(stderr, "initial tokens: [ ");
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||
}
|
||||
fprintf(stderr, "]\n");
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
|
||||
@ -536,14 +486,14 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return 5;
|
||||
}
|
||||
@ -559,7 +509,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
||||
return 8;
|
||||
}
|
||||
|
||||
@ -578,11 +528,11 @@ int main(int argc, char ** argv) {
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
@ -593,7 +543,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
@ -627,14 +577,13 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
@ -646,19 +595,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
wparams.temperature_inc = -1;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
// this callback is called on each new segment
|
||||
@ -673,7 +613,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
@ -713,13 +653,6 @@ int main(int argc, char ** argv) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_inp + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -49,9 +49,6 @@ void stream_main(size_t index) {
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
printf("stream: using %d threads\n", wparams.n_threads);
|
||||
@ -132,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -10,23 +10,6 @@ More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/i
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
|
||||
|
||||
## Sliding window mode with VAD
|
||||
|
||||
Setting the `--step` argument to `0` enables the sliding window mode:
|
||||
|
||||
```java
|
||||
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
|
||||
```
|
||||
|
||||
In this mode, the tool will transcribe only after some speech activity is detected. A very
|
||||
basic VAD detector is used, but in theory a more sophisticated approach can be added. The
|
||||
`-vth` argument determines the VAD threshold - higher values will make it detect silence more often.
|
||||
It's best to tune it to the specific use case, but a value around `0.6` should be OK in general.
|
||||
When silence is detected, it will transcribe the last `--length` milliseconds of audio and output
|
||||
a transcription block that is suitable for parsing.
|
||||
|
||||
## Building
|
||||
|
||||
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
|
@ -1,21 +1,18 @@
|
||||
// Real-time speech recognition of input from a microphone
|
||||
//
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
@ -36,23 +33,19 @@ struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t step_ms = 3000;
|
||||
int32_t length_ms = 10000;
|
||||
int32_t keep_ms = 200;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool no_context = true;
|
||||
bool no_timestamps = false;
|
||||
bool print_special = false;
|
||||
bool no_timestamps = true;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out;
|
||||
std::string fname_out = "";
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -68,16 +61,13 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
@ -91,7 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
@ -100,16 +90,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
@ -120,58 +107,19 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
SDL_AudioDeviceID g_dev_id_in = 0;
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
bool audio_sdl_init(const int capture_id) {
|
||||
if (g_dev_id_in) {
|
||||
fprintf(stderr, "%s: already initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
return (1);
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
@ -190,232 +138,34 @@ bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.freq = WHISPER_SAMPLE_RATE;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
if (!g_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
g_dev_id_in = 0;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -423,46 +173,33 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
|
||||
|
||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context |= use_vad;
|
||||
params.max_tokens = 0;
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(params.length_ms);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
if (!audio_sdl_init(params.capture_id)) {
|
||||
fprintf(stderr, "%s: audio_sdl_init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
// whisper init
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
// whisper init
|
||||
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
|
||||
const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE;
|
||||
|
||||
std::vector<float> pcmf32(n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old;
|
||||
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1;
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
@ -474,28 +211,23 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
n_samples_step,
|
||||
float(n_samples_step)/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_len )/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_keep)/WHISPER_SAMPLE_RATE,
|
||||
n_samples,
|
||||
float(n_samples)/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_len)/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
if (!use_vad) {
|
||||
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
|
||||
} else {
|
||||
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int n_iter = 0;
|
||||
SDL_PauseAudioDevice(g_dev_id_in, 0);
|
||||
|
||||
int n_iter = 0;
|
||||
bool is_running = true;
|
||||
|
||||
std::ofstream fout;
|
||||
@ -510,9 +242,6 @@ int main(int argc, char ** argv) {
|
||||
printf("[Start speaking]");
|
||||
fflush(stdout);
|
||||
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
const auto t_start = t_last;
|
||||
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
@ -539,64 +268,35 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// process new audio
|
||||
|
||||
if (!use_vad) {
|
||||
while (true) {
|
||||
audio.get(params.step_ms, pcmf32_new);
|
||||
|
||||
if ((int) pcmf32_new.size() > 2*n_samples_step) {
|
||||
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
|
||||
audio.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((int) pcmf32_new.size() >= n_samples_step) {
|
||||
audio.clear();
|
||||
break;
|
||||
}
|
||||
|
||||
SDL_Delay(1);
|
||||
}
|
||||
|
||||
const int n_samples_new = pcmf32_new.size();
|
||||
|
||||
// take up to params.length_ms audio from previous iteration
|
||||
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
|
||||
|
||||
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
|
||||
|
||||
pcmf32.resize(n_samples_new + n_samples_take);
|
||||
|
||||
for (int i = 0; i < n_samples_take; i++) {
|
||||
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
|
||||
}
|
||||
|
||||
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
|
||||
|
||||
pcmf32_old = pcmf32;
|
||||
} else {
|
||||
const auto t_now = std::chrono::high_resolution_clock::now();
|
||||
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
|
||||
|
||||
if (t_diff < 2000) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
audio.get(2000, pcmf32_new);
|
||||
|
||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
audio.get(params.length_ms, pcmf32);
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
t_last = t_now;
|
||||
if (n_iter > 0 && SDL_GetQueuedAudioSize(g_dev_id_in) > 2*n_samples*sizeof(float)) {
|
||||
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
|
||||
SDL_ClearQueuedAudio(g_dev_id_in);
|
||||
}
|
||||
|
||||
while (SDL_GetQueuedAudioSize(g_dev_id_in) < n_samples*sizeof(float)) {
|
||||
SDL_Delay(1);
|
||||
}
|
||||
|
||||
const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
|
||||
|
||||
// take one second from previous iteration
|
||||
//const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
|
||||
|
||||
// take up to params.length_ms audio from previous iteration
|
||||
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
|
||||
|
||||
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
|
||||
|
||||
pcmf32.resize(n_samples_new + n_samples_take);
|
||||
|
||||
for (int i = 0; i < n_samples_take; i++) {
|
||||
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
|
||||
}
|
||||
|
||||
SDL_DequeueAudio(g_dev_id_in, pcmf32.data() + n_samples_take, n_samples_new*sizeof(float));
|
||||
|
||||
pcmf32_old = pcmf32;
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
@ -607,7 +307,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = !use_vad;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
@ -615,9 +315,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
// disable temperature fallback
|
||||
wparams.temperature_inc = -1.0f;
|
||||
|
||||
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
|
||||
|
||||
@ -628,21 +325,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// print result;
|
||||
{
|
||||
if (!use_vad) {
|
||||
printf("\33[2K\r");
|
||||
printf("\33[2K\r");
|
||||
|
||||
// print long empty line to clear the previous line
|
||||
printf("%s", std::string(100, ' ').c_str());
|
||||
// print long empty line to clear the previous line
|
||||
printf("%s", std::string(100, ' ').c_str());
|
||||
|
||||
printf("\33[2K\r");
|
||||
} else {
|
||||
const int64_t t1 = (t_last - t_start).count()/1000000;
|
||||
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
|
||||
|
||||
printf("\n");
|
||||
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
|
||||
printf("\n");
|
||||
}
|
||||
printf("\33[2K\r");
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
@ -670,16 +358,11 @@ int main(int argc, char ** argv) {
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout << std::endl;
|
||||
}
|
||||
|
||||
if (use_vad){
|
||||
printf("\n");
|
||||
printf("### Transcription %d END\n", n_iter);
|
||||
}
|
||||
}
|
||||
|
||||
++n_iter;
|
||||
|
||||
if (!use_vad && (n_iter % n_new_line) == 0) {
|
||||
if ((n_iter % n_new_line) == 0) {
|
||||
printf("\n");
|
||||
|
||||
// keep part of the audio for next iteration to try to mitigate word boundary issues
|
||||
@ -701,7 +384,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
if (g_dev_id_in >= 0) {
|
||||
SDL_CloseAudioDevice(g_dev_id_in);
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
@ -9,8 +9,6 @@ add_executable(${TARGET}
|
||||
gpt-2.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
@ -33,8 +31,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1800MB \
|
||||
-s TOTAL_MEMORY=1800MB \
|
||||
-s INITIAL_MEMORY=1600MB \
|
||||
-s TOTAL_MEMORY=1600MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -6,8 +6,6 @@ Talk with an Artificial Intelligence in your browser:
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/talk/
|
||||
|
||||
Terminal version: [examples/talk](/examples/talk)
|
||||
|
||||
## How it works?
|
||||
|
||||
This demo leverages 2 modern neural network models to create a high-quality voice chat directly in your browser:
|
||||
@ -36,7 +34,7 @@ In order to run this demo efficiently, you need to have the following:
|
||||
- Latest Chrome or Firefox browser (Safari is not supported)
|
||||
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
|
||||
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
|
||||
- The web-page uses about 1.8GB of RAM
|
||||
- The web-page uses about 1.6GB of RAM
|
||||
|
||||
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
|
||||
Also, the prompting strategy can likely be improved to achieve better results.
|
||||
|
@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
|
||||
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
|
||||
for (size_t i = 0; i < g_contexts.size(); ++i) {
|
||||
if (g_contexts[i] == nullptr) {
|
||||
g_contexts[i] = whisper_init_from_file(path_model.c_str());
|
||||
g_contexts[i] = whisper_init(path_model.c_str());
|
||||
if (g_contexts[i] != nullptr) {
|
||||
g_running = true;
|
||||
if (g_worker.joinable()) {
|
||||
|
@ -325,9 +325,10 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
@ -528,14 +529,13 @@ bool gpt2_eval(
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = buf_size,
|
||||
.mem_buffer = buf,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
gf.n_threads = n_threads;
|
||||
struct ggml_cgraph gf = { .n_threads = n_threads };
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
|
1
examples/talk/.gitignore
vendored
1
examples/talk/.gitignore
vendored
@ -1 +0,0 @@
|
||||
eleven-labs.py
|
@ -1,16 +0,0 @@
|
||||
if (WHISPER_SUPPORT_SDL2)
|
||||
# talk
|
||||
set(TARGET talk)
|
||||
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
|
||||
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
@ -1,41 +0,0 @@
|
||||
# talk
|
||||
|
||||
Talk with an Artificial Intelligence in your terminal
|
||||
|
||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/206805012-48e71cc2-588d-4745-8798-c1c70ea3b40d.mp4)
|
||||
|
||||
Web version: [examples/talk.wasm](/examples/talk.wasm)
|
||||
|
||||
## Building
|
||||
|
||||
The `talk` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
# Install SDL2 on Linux
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
# Install SDL2 on Mac OS
|
||||
brew install sdl2
|
||||
|
||||
# Build the "talk" executable
|
||||
make talk
|
||||
|
||||
# Run it
|
||||
./talk -p Santa
|
||||
```
|
||||
|
||||
## GPT-2
|
||||
|
||||
To run this, you will need a ggml GPT-2 model: [instructions](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2#downloading-and-converting-the-original-models)
|
||||
|
||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||
|
||||
```
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
||||
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
|
||||
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
|
||||
By default, it is configured to use `espeak`, but you can use whatever you wish.
|
@ -1,923 +0,0 @@
|
||||
#include "ggml.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <random>
|
||||
|
||||
/////////////////////// GPT-2 BEGIN /////////////////////////
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.empty()) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double /*temp*/,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
// normalize
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
{
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += logits_id[i].first;
|
||||
if (cumsum >= top_p) {
|
||||
logits_id.resize(i+1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize again
|
||||
{
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
sum += logits_id[i].first;
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
logits_id[i].first *= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
// sample from the obtained distribution
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
probs.push_back(logits_id[i].first);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
|
||||
// default hparams (GPT-2 117M)
|
||||
struct gpt2_hparams {
|
||||
int32_t n_vocab = 50257;
|
||||
int32_t n_ctx = 1024;
|
||||
int32_t n_embd = 768;
|
||||
int32_t n_head = 12;
|
||||
int32_t n_layer = 12;
|
||||
int32_t f16 = 1;
|
||||
};
|
||||
|
||||
struct gpt2_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * ln_1_g;
|
||||
struct ggml_tensor * ln_1_b;
|
||||
|
||||
struct ggml_tensor * ln_2_g;
|
||||
struct ggml_tensor * ln_2_b;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * c_attn_attn_w;
|
||||
struct ggml_tensor * c_attn_attn_b;
|
||||
|
||||
struct ggml_tensor * c_attn_proj_w;
|
||||
struct ggml_tensor * c_attn_proj_b;
|
||||
|
||||
// mlp
|
||||
struct ggml_tensor * c_mlp_fc_w;
|
||||
struct ggml_tensor * c_mlp_fc_b;
|
||||
|
||||
struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
struct gpt2_model {
|
||||
gpt2_hparams hparams;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ln_f_g;
|
||||
struct ggml_tensor * ln_f_b;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
struct ggml_tensor * wpe; // token embedding
|
||||
|
||||
std::vector<gpt2_layer> layers;
|
||||
|
||||
// key + value memory
|
||||
struct ggml_tensor * memory_k;
|
||||
struct ggml_tensor * memory_v;
|
||||
|
||||
//
|
||||
struct ggml_context * ctx;
|
||||
std::map<std::string, struct ggml_tensor *> tensors;
|
||||
};
|
||||
|
||||
// load the model's weights from a file
|
||||
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
|
||||
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// load hparams
|
||||
{
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
if (n_vocab != model.hparams.n_vocab) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) &word[0], len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
|
||||
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
|
||||
|
||||
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
|
||||
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
}
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = nullptr;
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prepare memory for the weights
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/ln_f/g"] = model.ln_f_g;
|
||||
model.tensors["model/ln_f/b"] = model.ln_f_b;
|
||||
|
||||
model.tensors["model/wte"] = model.wte;
|
||||
model.tensors["model/wpe"] = model.wpe;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
|
||||
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
|
||||
layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
|
||||
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
|
||||
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
|
||||
}
|
||||
}
|
||||
|
||||
// key + value memory
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
|
||||
|
||||
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
|
||||
|
||||
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
|
||||
}
|
||||
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
||||
|
||||
if (nelements*bpe != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
total_size += ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - model: the model
|
||||
// - n_threads: number of threads to use
|
||||
// - n_past: the context size so far
|
||||
// - embd_inp: the embeddings of the tokens in the context
|
||||
// - embd_w: the predicted probabilities of the next token
|
||||
//
|
||||
bool gpt2_eval(
|
||||
const gpt2_model & model,
|
||||
const int n_threads,
|
||||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
static size_t buf_size = 5640ull*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = buf_size;
|
||||
params.mem_buffer = buf;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = { };
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
((int32_t *) position->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
// wte + wpe
|
||||
struct ggml_tensor * inpL =
|
||||
ggml_add(ctx0,
|
||||
ggml_get_rows(ctx0, model.wte, embd),
|
||||
ggml_get_rows(ctx0, model.wpe, position));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
|
||||
// cur = ln_1_g*cur + ln_1_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
|
||||
}
|
||||
|
||||
// attn
|
||||
// [2304, 768] - model.layers[il].c_attn_attn_w
|
||||
// [2304, 1] - model.layers[il].c_attn_attn_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [2304, N] - cur (out)
|
||||
//
|
||||
// cur = attn_w*cur + attn_b
|
||||
// [2304, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
|
||||
|
||||
// store key and value to memory
|
||||
if (N >= 1) {
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
// [64, n_past + N, 12]
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// GG: flash attention
|
||||
//struct ggml_tensor * V =
|
||||
// ggml_cpy(ctx0,
|
||||
// ggml_permute(ctx0,
|
||||
// ggml_reshape_3d(ctx0,
|
||||
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
// n_embd/n_head, n_head, n_past + N),
|
||||
// 1, 2, 0, 3),
|
||||
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
|
||||
|
||||
// K * Q
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3);
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
// [64, 12, N]
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
// [768, N]
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
}
|
||||
|
||||
// projection
|
||||
// [ 768, 768] - model.layers[il].c_attn_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_attn_proj_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpFF);
|
||||
|
||||
// cur = ln_2_g*cur + ln_2_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
|
||||
}
|
||||
|
||||
// fully connected
|
||||
// [3072, 768] - model.layers[il].c_mlp_fc_w
|
||||
// [3072, 1] - model.layers[il].c_mlp_fc_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [3072, N] - cur (out)
|
||||
//
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
|
||||
cur);
|
||||
|
||||
// GELU activation
|
||||
// [3072, N]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_mlp_proj_b
|
||||
// [3072, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].c_mlp_proj_w_trans,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
// input for next layer
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
// [ 768, N]
|
||||
inpL = ggml_norm(ctx0, inpL);
|
||||
|
||||
// inpL = ln_f_g*inpL + ln_f_b
|
||||
// [ 768, N]
|
||||
inpL = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.ln_f_g, inpL),
|
||||
inpL),
|
||||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.wte
|
||||
// [ 768, N] - inpL
|
||||
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
|
||||
|
||||
// logits -> probs
|
||||
inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute (ctx0, &gf);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
|
||||
//}
|
||||
|
||||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// return result for just the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////// GPT-2 END ////////////////////////////////
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
|
||||
struct gpt2_context {
|
||||
std::string prompt_base = R"(Hello, how are you?
|
||||
I'm fine, thanks. How are you?
|
||||
Thanks, I'm fine too. What are you doing?
|
||||
I'm just sitting here.
|
||||
It's a lovely day, isn't it?
|
||||
Yes, it is. I love the weather this time of year.
|
||||
I wish it would rain a little bit.
|
||||
Me too.
|
||||
)";
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
gpt_vocab vocab;
|
||||
gpt2_model model;
|
||||
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 5;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(nullptr));
|
||||
|
||||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int64_t t_load_us = ggml_time_us() - t_start_us;
|
||||
|
||||
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void gpt2_free(struct gpt2_context * ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
|
||||
return ctx->prompt_base.c_str();
|
||||
}
|
||||
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
|
||||
ctx->prompt_base = prompt;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
|
||||
return ::gpt_tokenize(ctx->vocab, text);
|
||||
}
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
|
||||
int n_past = 0;
|
||||
|
||||
std::vector<float> embd_w;
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
|
||||
|
||||
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
|
||||
|
||||
std::vector<gpt_vocab::id> embd = embd_inp;
|
||||
|
||||
size_t mem_per_token = 3000000;
|
||||
|
||||
std::string result;
|
||||
|
||||
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (!embd.empty()) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
n_past += embd.size();
|
||||
embd.clear();
|
||||
|
||||
{
|
||||
// sample next token
|
||||
const int top_k = ctx->top_k;
|
||||
const float top_p = ctx->top_p;
|
||||
const float temp = ctx->temp;
|
||||
|
||||
const int n_vocab = ctx->model.hparams.n_vocab;
|
||||
|
||||
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
|
||||
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
}
|
||||
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
@ -1,27 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: Change to C-style API and move to ./examples for easy reuse.
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
struct gpt2_context;
|
||||
|
||||
struct gpt2_context * gpt2_init(const char * path_model);
|
||||
void gpt2_free(struct gpt2_context * ctx);
|
||||
|
||||
const char * gpt2_get_prompt(struct gpt2_context * ctx);
|
||||
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
|
||||
|
||||
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
|
||||
|
||||
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);
|
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Usage:
|
||||
# speak.sh <voice_id> <text-to-speak>
|
||||
|
||||
# espeak
|
||||
# Mac OS: brew install espeak
|
||||
# Linux: apt-get install espeak
|
||||
#
|
||||
espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
|
||||
|
||||
# Eleven Labs
|
||||
#
|
||||
#wd=$(dirname $0)
|
||||
#script=$wd/eleven-labs.py
|
||||
#python3 $script $1 "$2"
|
||||
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3
|
@ -1,695 +0,0 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t voice_ms = 10000;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
|
||||
std::string person = "Santa";
|
||||
std::string language = "en";
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
|
||||
std::string speak = "./examples/talk/speak.sh";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
||||
else if (arg == "-mg" || arg == "--model-gpt") { params.model_gpt = argv[++i]; }
|
||||
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
||||
fprintf(stderr, " -mg FILE, --model-gpt [%-7s] gpt model file\n", params.model_gpt.c_str());
|
||||
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
|
||||
{
|
||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||
for (int i = 0; i < nDevices; i++) {
|
||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||
}
|
||||
}
|
||||
|
||||
SDL_AudioSpec capture_spec_requested;
|
||||
SDL_AudioSpec capture_spec_obtained;
|
||||
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
std::string trim(const std::string & s) {
|
||||
std::regex e("^\\s+|\\s+$");
|
||||
return std::regex_replace(s, e, "");
|
||||
}
|
||||
|
||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||
std::string result = s;
|
||||
size_t pos = 0;
|
||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||
result.replace(pos, from.length(), to);
|
||||
pos += to.length();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string k_prompt =
|
||||
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
|
||||
|
||||
B: Hello {0}, how are you?
|
||||
A: I'm fine, thank you.
|
||||
{1}
|
||||
Here is how {0} (A) continues the dialogue:
|
||||
|
||||
A:)";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
|
||||
|
||||
// gpt init
|
||||
|
||||
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx_wsp)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, "");
|
||||
|
||||
const int voice_id = rand()%6;
|
||||
|
||||
fprintf(stderr, "gpt-2: prompt:\n");
|
||||
fprintf(stderr, "========================\n\n");
|
||||
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
|
||||
fprintf(stderr, "========================\n\n");
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\(.*?\\)");
|
||||
text_heard = std::regex_replace(text_heard, re, "");
|
||||
}
|
||||
|
||||
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
|
||||
|
||||
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
|
||||
|
||||
if (text_heard.empty() || tokens.empty() || force_speak) {
|
||||
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
|
||||
audio.clear();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
|
||||
|
||||
std::string text_to_speak;
|
||||
|
||||
{
|
||||
prompt_base += "B: " + text_heard + "\n";
|
||||
|
||||
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
|
||||
|
||||
// remove first 2 lines of base prompt
|
||||
if (n_iter > 4) {
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt_base += "A:" + text_to_speak + "\n";
|
||||
|
||||
{
|
||||
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
printf("===============\n");
|
||||
printf("prompt:\n");
|
||||
printf("%s\n", prompt.c_str());
|
||||
printf("===============\n");
|
||||
}
|
||||
}
|
||||
|
||||
//printf("========================\n");
|
||||
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
|
||||
//printf("========================\n");
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
|
||||
audio.clear();
|
||||
|
||||
++n_iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx_wsp);
|
||||
whisper_free(ctx_wsp);
|
||||
|
||||
return 0;
|
||||
}
|
@ -1,109 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Transcribe twitch.tv livestream by feeding audio input to whisper.cpp at regular intervals
|
||||
# Thanks to @keyehzy
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/209
|
||||
#
|
||||
# The script currently depends on the third-party tool "streamlink"
|
||||
# On Mac OS, you can install it via "brew install streamlink"
|
||||
#
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
step=10
|
||||
model=base.en
|
||||
threads=4
|
||||
|
||||
help()
|
||||
{
|
||||
echo "Example program for captioning a livestream from twitch.tv."
|
||||
echo
|
||||
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
|
||||
echo "options:"
|
||||
echo "-s Step in seconds (default is $step)."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large' (default is '$model')."
|
||||
echo "-t Number of threads to use."
|
||||
echo "-h Print this help page."
|
||||
echo
|
||||
}
|
||||
|
||||
check_requirements()
|
||||
{
|
||||
if ! command -v ./main &>/dev/null; then
|
||||
echo "whisper.cpp main executable is required (make)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v streamlink &>/dev/null; then
|
||||
echo "streamlink is required (https://streamlink.github.io)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v ffmpeg &>/dev/null; then
|
||||
echo "ffmpeg is required (https://ffmpeg.org)"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_requirements
|
||||
|
||||
while getopts ":s:m:t:h" option; do
|
||||
case $option in
|
||||
s)
|
||||
step=$OPTARG;;
|
||||
m)
|
||||
model=$OPTARG;;
|
||||
t)
|
||||
threads=$OPTARG;;
|
||||
h)
|
||||
help
|
||||
exit;;
|
||||
\?)
|
||||
help
|
||||
exit;;
|
||||
esac
|
||||
done
|
||||
|
||||
url=${@:$OPTIND:1}
|
||||
|
||||
if [ -z $url ]; then
|
||||
help
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "Piping from streamlink url=$url model=$model step=$step threads=$threads"
|
||||
streamlink $url best -O 2>/dev/null | ffmpeg -loglevel quiet -i - -y -probesize 32 -y -ar 16000 -ac 1 -acodec pcm_s16le /tmp/whisper-live0.wav &
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
printf "error: ffmpeg failed\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Buffering stream... (this should take $step seconds)"
|
||||
sleep $(($step))
|
||||
|
||||
set +e
|
||||
|
||||
echo "Starting..."
|
||||
|
||||
i=0
|
||||
SECONDS=0
|
||||
while true
|
||||
do
|
||||
err=1
|
||||
while [ $err -ne 0 ]; do
|
||||
if [ $i -gt 0 ]; then
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step-1)).5 -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
else
|
||||
ffmpeg -loglevel quiet -v error -noaccurate_seek -i /tmp/whisper-live0.wav -y -ss $(($i*$step)) -t $step -c copy /tmp/whisper-live.wav 2> /tmp/whisper-live.err
|
||||
fi
|
||||
err=$(cat /tmp/whisper-live.err | wc -l)
|
||||
done
|
||||
|
||||
./main -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
|
||||
|
||||
while [ $SECONDS -lt $((($i+1)*$step)) ]; do
|
||||
sleep 1
|
||||
done
|
||||
((i=i+1))
|
||||
done
|
15
examples/whisper.android/.gitignore
vendored
15
examples/whisper.android/.gitignore
vendored
@ -1,15 +0,0 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
3
examples/whisper.android/.idea/.gitignore
generated
vendored
3
examples/whisper.android/.idea/.gitignore
generated
vendored
@ -1,3 +0,0 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
1
examples/whisper.android/.idea/.name
generated
1
examples/whisper.android/.idea/.name
generated
@ -1 +0,0 @@
|
||||
WhisperCppDemo
|
6
examples/whisper.android/.idea/compiler.xml
generated
6
examples/whisper.android/.idea/compiler.xml
generated
@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CompilerConfiguration">
|
||||
<bytecodeTargetLevel target="11" />
|
||||
</component>
|
||||
</project>
|
19
examples/whisper.android/.idea/gradle.xml
generated
19
examples/whisper.android/.idea/gradle.xml
generated
@ -1,19 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="distributionType" value="DEFAULT_WRAPPED" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="modules">
|
||||
<set>
|
||||
<option value="$PROJECT_DIR$" />
|
||||
<option value="$PROJECT_DIR$/app" />
|
||||
</set>
|
||||
</option>
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
10
examples/whisper.android/.idea/misc.xml
generated
10
examples/whisper.android/.idea/misc.xml
generated
@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/build/classes" />
|
||||
</component>
|
||||
<component name="ProjectType">
|
||||
<option name="id" value="Android" />
|
||||
</component>
|
||||
</project>
|
6
examples/whisper.android/.idea/vcs.xml
generated
6
examples/whisper.android/.idea/vcs.xml
generated
@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
|
||||
</component>
|
||||
</project>
|
@ -1,12 +0,0 @@
|
||||
A sample Android app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
|
1
examples/whisper.android/app/.gitignore
vendored
1
examples/whisper.android/app/.gitignore
vendored
@ -1 +0,0 @@
|
||||
/build
|
@ -1,72 +0,0 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
id 'org.jetbrains.kotlin.android'
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'com.whispercppdemo'
|
||||
compileSdk 33
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.whispercppdemo"
|
||||
minSdk 26
|
||||
targetSdk 32
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
vectorDrawables {
|
||||
useSupportLibrary true
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
}
|
||||
buildFeatures {
|
||||
compose true
|
||||
}
|
||||
composeOptions {
|
||||
kotlinCompilerExtensionVersion '1.3.1'
|
||||
}
|
||||
ndkVersion "25.1.8937393"
|
||||
externalNativeBuild {
|
||||
ndkBuild {
|
||||
path 'src/main/jni/whisper/Android.mk'
|
||||
}
|
||||
}
|
||||
packagingOptions {
|
||||
resources {
|
||||
excludes += '/META-INF/{AL2.0,LGPL2.1}'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.activity:activity-compose:1.6.1'
|
||||
implementation 'androidx.compose.material:material-icons-core:1.3.1'
|
||||
implementation 'androidx.compose.material3:material3:1.0.1'
|
||||
implementation "androidx.compose.ui:ui:1.3.2"
|
||||
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
|
||||
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
|
||||
|
||||
testImplementation 'junit:junit:4.13.2'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
|
||||
}
|
21
examples/whisper.android/app/proguard-rules.pro
vendored
21
examples/whisper.android/app/proguard-rules.pro
vendored
@ -1,21 +0,0 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -1,24 +0,0 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("com.whispercppdemo", appContext.packageName)
|
||||
}
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools">
|
||||
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||
android:fullBackupContent="@xml/backup_rules"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo"
|
||||
tools:targetApi="31">
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
|
||||
<meta-data
|
||||
android:name="android.app.lib_name"
|
||||
android:value="" />
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -1,22 +0,0 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import android.os.Bundle
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.viewModels
|
||||
import com.whispercppdemo.ui.main.MainScreen
|
||||
import com.whispercppdemo.ui.main.MainScreenViewModel
|
||||
import com.whispercppdemo.ui.theme.WhisperCppDemoTheme
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
private val viewModel: MainScreenViewModel by viewModels { MainScreenViewModel.factory() }
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
setContent {
|
||||
WhisperCppDemoTheme {
|
||||
MainScreen(viewModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
package com.whispercppdemo.media
|
||||
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.File
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
fun decodeWaveFile(file: File): FloatArray {
|
||||
val baos = ByteArrayOutputStream()
|
||||
file.inputStream().use { it.copyTo(baos) }
|
||||
val buffer = ByteBuffer.wrap(baos.toByteArray())
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
buffer.position(44)
|
||||
val shortBuffer = buffer.asShortBuffer()
|
||||
val shortArray = ShortArray(shortBuffer.limit())
|
||||
shortBuffer.get(shortArray)
|
||||
return FloatArray(shortArray.size) { index ->
|
||||
(shortArray[index] / 32767.0f).coerceIn(-1f..1f)
|
||||
}
|
||||
}
|
||||
|
||||
fun encodeWaveFile(file: File, data: ShortArray) {
|
||||
file.outputStream().use {
|
||||
it.write(headerBytes(data.size * 2))
|
||||
val buffer = ByteBuffer.allocate(data.size * 2)
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
buffer.asShortBuffer().put(data)
|
||||
val bytes = ByteArray(buffer.limit())
|
||||
buffer.get(bytes)
|
||||
it.write(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
private fun headerBytes(totalLength: Int): ByteArray {
|
||||
require(totalLength >= 44)
|
||||
ByteBuffer.allocate(44).apply {
|
||||
order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
put('R'.code.toByte())
|
||||
put('I'.code.toByte())
|
||||
put('F'.code.toByte())
|
||||
put('F'.code.toByte())
|
||||
|
||||
putInt(totalLength - 8)
|
||||
|
||||
put('W'.code.toByte())
|
||||
put('A'.code.toByte())
|
||||
put('V'.code.toByte())
|
||||
put('E'.code.toByte())
|
||||
|
||||
put('f'.code.toByte())
|
||||
put('m'.code.toByte())
|
||||
put('t'.code.toByte())
|
||||
put(' '.code.toByte())
|
||||
|
||||
putInt(16)
|
||||
putShort(1.toShort())
|
||||
putShort(1.toShort())
|
||||
putInt(16000)
|
||||
putInt(32000)
|
||||
putShort(2.toShort())
|
||||
putShort(16.toShort())
|
||||
|
||||
put('d'.code.toByte())
|
||||
put('a'.code.toByte())
|
||||
put('t'.code.toByte())
|
||||
put('a'.code.toByte())
|
||||
|
||||
putInt(totalLength - 44)
|
||||
position(0)
|
||||
}.also {
|
||||
val bytes = ByteArray(it.limit())
|
||||
it.get(bytes)
|
||||
return bytes
|
||||
}
|
||||
}
|
@ -1,88 +0,0 @@
|
||||
package com.whispercppdemo.recorder
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder
|
||||
import com.whispercppdemo.media.encodeWaveFile
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.asCoroutineDispatcher
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
class Recorder {
|
||||
private val scope: CoroutineScope = CoroutineScope(
|
||||
Executors.newSingleThreadExecutor().asCoroutineDispatcher()
|
||||
)
|
||||
private var recorder: AudioRecordThread? = null
|
||||
|
||||
suspend fun startRecording(outputFile: File, onError: (Exception) -> Unit) = withContext(scope.coroutineContext) {
|
||||
recorder = AudioRecordThread(outputFile, onError)
|
||||
recorder?.start()
|
||||
}
|
||||
|
||||
suspend fun stopRecording() = withContext(scope.coroutineContext) {
|
||||
recorder?.stopRecording()
|
||||
@Suppress("BlockingMethodInNonBlockingContext")
|
||||
recorder?.join()
|
||||
recorder = null
|
||||
}
|
||||
}
|
||||
|
||||
private class AudioRecordThread(
|
||||
private val outputFile: File,
|
||||
private val onError: (Exception) -> Unit
|
||||
) :
|
||||
Thread("AudioRecorder") {
|
||||
private var quit = AtomicBoolean(false)
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun run() {
|
||||
try {
|
||||
val bufferSize = AudioRecord.getMinBufferSize(
|
||||
16000,
|
||||
AudioFormat.CHANNEL_IN_MONO,
|
||||
AudioFormat.ENCODING_PCM_16BIT
|
||||
) * 4
|
||||
val buffer = ShortArray(bufferSize / 2)
|
||||
|
||||
val audioRecord = AudioRecord(
|
||||
MediaRecorder.AudioSource.MIC,
|
||||
16000,
|
||||
AudioFormat.CHANNEL_IN_MONO,
|
||||
AudioFormat.ENCODING_PCM_16BIT,
|
||||
bufferSize
|
||||
)
|
||||
|
||||
try {
|
||||
audioRecord.startRecording()
|
||||
|
||||
val allData = mutableListOf<Short>()
|
||||
|
||||
while (!quit.get()) {
|
||||
val read = audioRecord.read(buffer, 0, buffer.size)
|
||||
if (read > 0) {
|
||||
for (i in 0 until read) {
|
||||
allData.add(buffer[i])
|
||||
}
|
||||
} else {
|
||||
throw java.lang.RuntimeException("audioRecord.read returned $read")
|
||||
}
|
||||
}
|
||||
|
||||
audioRecord.stop()
|
||||
encodeWaveFile(outputFile, allData.toShortArray())
|
||||
} finally {
|
||||
audioRecord.release()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
onError(e)
|
||||
}
|
||||
}
|
||||
|
||||
fun stopRecording() {
|
||||
quit.set(true)
|
||||
}
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
package com.whispercppdemo.ui.main
|
||||
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.accompanist.permissions.ExperimentalPermissionsApi
|
||||
import com.google.accompanist.permissions.isGranted
|
||||
import com.google.accompanist.permissions.rememberPermissionState
|
||||
import com.whispercppdemo.R
|
||||
|
||||
@Composable
|
||||
fun MainScreen(viewModel: MainScreenViewModel) {
|
||||
MainScreen(
|
||||
canTranscribe = viewModel.canTranscribe,
|
||||
isRecording = viewModel.isRecording,
|
||||
messageLog = viewModel.dataLog,
|
||||
onTranscribeSampleTapped = viewModel::transcribeSample,
|
||||
onRecordTapped = viewModel::toggleRecord
|
||||
)
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
private fun MainScreen(
|
||||
canTranscribe: Boolean,
|
||||
isRecording: Boolean,
|
||||
messageLog: String,
|
||||
onTranscribeSampleTapped: () -> Unit,
|
||||
onRecordTapped: () -> Unit
|
||||
) {
|
||||
Scaffold(
|
||||
topBar = {
|
||||
TopAppBar(
|
||||
title = { Text(stringResource(R.string.app_name)) }
|
||||
)
|
||||
},
|
||||
) { innerPadding ->
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(innerPadding)
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Row(horizontalArrangement = Arrangement.SpaceBetween) {
|
||||
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
|
||||
RecordButton(
|
||||
enabled = canTranscribe,
|
||||
isRecording = isRecording,
|
||||
onClick = onRecordTapped
|
||||
)
|
||||
}
|
||||
MessageLog(messageLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun MessageLog(log: String) {
|
||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) {
|
||||
Button(onClick = onClick, enabled = enabled) {
|
||||
Text("Transcribe sample")
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalPermissionsApi::class)
|
||||
@Composable
|
||||
private fun RecordButton(enabled: Boolean, isRecording: Boolean, onClick: () -> Unit) {
|
||||
val micPermissionState = rememberPermissionState(
|
||||
permission = android.Manifest.permission.RECORD_AUDIO,
|
||||
onPermissionResult = { granted ->
|
||||
if (granted) {
|
||||
onClick()
|
||||
}
|
||||
}
|
||||
)
|
||||
Button(onClick = {
|
||||
if (micPermissionState.status.isGranted) {
|
||||
onClick()
|
||||
} else {
|
||||
micPermissionState.launchPermissionRequest()
|
||||
}
|
||||
}, enabled = enabled) {
|
||||
Text(
|
||||
if (isRecording) {
|
||||
"Stop recording"
|
||||
} else {
|
||||
"Start recording"
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
@ -1,198 +0,0 @@
|
||||
package com.whispercppdemo.ui.main
|
||||
|
||||
import android.app.Application
|
||||
import android.content.Context
|
||||
import android.media.MediaPlayer
|
||||
import android.util.Log
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.core.net.toUri
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import androidx.lifecycle.viewmodel.initializer
|
||||
import androidx.lifecycle.viewmodel.viewModelFactory
|
||||
import com.whispercppdemo.media.decodeWaveFile
|
||||
import com.whispercppdemo.recorder.Recorder
|
||||
import com.whispercppdemo.whisper.WhisperContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
|
||||
private const val LOG_TAG = "MainScreenViewModel"
|
||||
|
||||
class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
var canTranscribe by mutableStateOf(false)
|
||||
private set
|
||||
var dataLog by mutableStateOf("")
|
||||
private set
|
||||
var isRecording by mutableStateOf(false)
|
||||
private set
|
||||
|
||||
private val modelsPath = File(application.filesDir, "models")
|
||||
private val samplesPath = File(application.filesDir, "samples")
|
||||
private var recorder: Recorder = Recorder()
|
||||
private var whisperContext: WhisperContext? = null
|
||||
private var mediaPlayer: MediaPlayer? = null
|
||||
private var recordedFile: File? = null
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
loadData()
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun loadData() {
|
||||
printMessage("Loading data...\n")
|
||||
try {
|
||||
copyAssets()
|
||||
loadBaseModel()
|
||||
canTranscribe = true
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun printMessage(msg: String) = withContext(Dispatchers.Main) {
|
||||
dataLog += msg
|
||||
}
|
||||
|
||||
private suspend fun copyAssets() = withContext(Dispatchers.IO) {
|
||||
modelsPath.mkdirs()
|
||||
samplesPath.mkdirs()
|
||||
//application.copyData("models", modelsPath, ::printMessage)
|
||||
application.copyData("samples", samplesPath, ::printMessage)
|
||||
printMessage("All data copied to working directory.\n")
|
||||
}
|
||||
|
||||
private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
|
||||
printMessage("Loading model...\n")
|
||||
val models = application.assets.list("models/")
|
||||
if (models != null) {
|
||||
whisperContext = WhisperContext.createContextFromAsset(application.assets, "models/" + models[0])
|
||||
printMessage("Loaded model ${models[0]}.\n")
|
||||
}
|
||||
|
||||
//val firstModel = modelsPath.listFiles()!!.first()
|
||||
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
|
||||
}
|
||||
|
||||
fun transcribeSample() = viewModelScope.launch {
|
||||
transcribeAudio(getFirstSample())
|
||||
}
|
||||
|
||||
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
|
||||
samplesPath.listFiles()!!.first()
|
||||
}
|
||||
|
||||
private suspend fun readAudioSamples(file: File): FloatArray = withContext(Dispatchers.IO) {
|
||||
stopPlayback()
|
||||
startPlayback(file)
|
||||
return@withContext decodeWaveFile(file)
|
||||
}
|
||||
|
||||
private suspend fun stopPlayback() = withContext(Dispatchers.Main) {
|
||||
mediaPlayer?.stop()
|
||||
mediaPlayer?.release()
|
||||
mediaPlayer = null
|
||||
}
|
||||
|
||||
private suspend fun startPlayback(file: File) = withContext(Dispatchers.Main) {
|
||||
mediaPlayer = MediaPlayer.create(application, file.absolutePath.toUri())
|
||||
mediaPlayer?.start()
|
||||
}
|
||||
|
||||
private suspend fun transcribeAudio(file: File) {
|
||||
if (!canTranscribe) {
|
||||
return
|
||||
}
|
||||
|
||||
canTranscribe = false
|
||||
|
||||
try {
|
||||
printMessage("Reading wave samples...\n")
|
||||
val data = readAudioSamples(file)
|
||||
printMessage("Transcribing data...\n")
|
||||
val text = whisperContext?.transcribeData(data)
|
||||
printMessage("Done: $text\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
}
|
||||
|
||||
canTranscribe = true
|
||||
}
|
||||
|
||||
fun toggleRecord() = viewModelScope.launch {
|
||||
try {
|
||||
if (isRecording) {
|
||||
recorder.stopRecording()
|
||||
isRecording = false
|
||||
recordedFile?.let { transcribeAudio(it) }
|
||||
} else {
|
||||
stopPlayback()
|
||||
val file = getTempFileForRecording()
|
||||
recorder.startRecording(file) { e ->
|
||||
viewModelScope.launch {
|
||||
withContext(Dispatchers.Main) {
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
}
|
||||
isRecording = true
|
||||
recordedFile = file
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun getTempFileForRecording() = withContext(Dispatchers.IO) {
|
||||
File.createTempFile("recording", "wav")
|
||||
}
|
||||
|
||||
override fun onCleared() {
|
||||
runBlocking {
|
||||
whisperContext?.release()
|
||||
whisperContext = null
|
||||
stopPlayback()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun factory() = viewModelFactory {
|
||||
initializer {
|
||||
val application =
|
||||
this[ViewModelProvider.AndroidViewModelFactory.APPLICATION_KEY] as Application
|
||||
MainScreenViewModel(application)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun Context.copyData(
|
||||
assetDirName: String,
|
||||
destDir: File,
|
||||
printMessage: suspend (String) -> Unit
|
||||
) = withContext(Dispatchers.IO) {
|
||||
assets.list(assetDirName)?.forEach { name ->
|
||||
val assetPath = "$assetDirName/$name"
|
||||
Log.v(LOG_TAG, "Processing $assetPath...")
|
||||
val destination = File(destDir, name)
|
||||
Log.v(LOG_TAG, "Copying $assetPath to $destination...")
|
||||
printMessage("Copying $name...\n")
|
||||
assets.open(assetPath).use { input ->
|
||||
destination.outputStream().use { output ->
|
||||
input.copyTo(output)
|
||||
}
|
||||
}
|
||||
Log.v(LOG_TAG, "Copied $assetPath to $destination")
|
||||
}
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import androidx.compose.ui.graphics.Color
|
||||
|
||||
val Purple80 = Color(0xFFD0BCFF)
|
||||
val PurpleGrey80 = Color(0xFFCCC2DC)
|
||||
val Pink80 = Color(0xFFEFB8C8)
|
||||
|
||||
val Purple40 = Color(0xFF6650a4)
|
||||
val PurpleGrey40 = Color(0xFF625b71)
|
||||
val Pink40 = Color(0xFF7D5260)
|
@ -1,68 +0,0 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import android.app.Activity
|
||||
import android.os.Build
|
||||
import androidx.compose.foundation.isSystemInDarkTheme
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.darkColorScheme
|
||||
import androidx.compose.material3.dynamicDarkColorScheme
|
||||
import androidx.compose.material3.dynamicLightColorScheme
|
||||
import androidx.compose.material3.lightColorScheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.SideEffect
|
||||
import androidx.compose.ui.graphics.toArgb
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalView
|
||||
import androidx.core.view.ViewCompat
|
||||
|
||||
private val DarkColorScheme = darkColorScheme(
|
||||
primary = Purple80,
|
||||
secondary = PurpleGrey80,
|
||||
tertiary = Pink80
|
||||
)
|
||||
|
||||
private val LightColorScheme = lightColorScheme(
|
||||
primary = Purple40,
|
||||
secondary = PurpleGrey40,
|
||||
tertiary = Pink40
|
||||
|
||||
/* Other default colors to override
|
||||
background = Color(0xFFFFFBFE),
|
||||
surface = Color(0xFFFFFBFE),
|
||||
onPrimary = Color.White,
|
||||
onSecondary = Color.White,
|
||||
onTertiary = Color.White,
|
||||
onBackground = Color(0xFF1C1B1F),
|
||||
onSurface = Color(0xFF1C1B1F),
|
||||
*/
|
||||
)
|
||||
|
||||
@Composable
|
||||
fun WhisperCppDemoTheme(
|
||||
darkTheme: Boolean = isSystemInDarkTheme(),
|
||||
// Dynamic color is available on Android 12+
|
||||
dynamicColor: Boolean = true,
|
||||
content: @Composable () -> Unit
|
||||
) {
|
||||
val colorScheme = when {
|
||||
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
|
||||
val context = LocalContext.current
|
||||
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
|
||||
}
|
||||
darkTheme -> DarkColorScheme
|
||||
else -> LightColorScheme
|
||||
}
|
||||
val view = LocalView.current
|
||||
if (!view.isInEditMode) {
|
||||
SideEffect {
|
||||
(view.context as Activity).window.statusBarColor = colorScheme.primary.toArgb()
|
||||
ViewCompat.getWindowInsetsController(view)?.isAppearanceLightStatusBars = darkTheme
|
||||
}
|
||||
}
|
||||
|
||||
MaterialTheme(
|
||||
colorScheme = colorScheme,
|
||||
typography = Typography,
|
||||
content = content
|
||||
)
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import androidx.compose.material3.Typography
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.sp
|
||||
|
||||
// Set of Material typography styles to start with
|
||||
val Typography = Typography(
|
||||
bodyLarge = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Normal,
|
||||
fontSize = 16.sp,
|
||||
lineHeight = 24.sp,
|
||||
letterSpacing = 0.5.sp
|
||||
)
|
||||
/* Other default text styles to override
|
||||
titleLarge = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Normal,
|
||||
fontSize = 22.sp,
|
||||
lineHeight = 28.sp,
|
||||
letterSpacing = 0.sp
|
||||
),
|
||||
labelSmall = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Medium,
|
||||
fontSize = 11.sp,
|
||||
lineHeight = 16.sp,
|
||||
letterSpacing = 0.5.sp
|
||||
)
|
||||
*/
|
||||
)
|
@ -1,122 +0,0 @@
|
||||
package com.whispercppdemo.whisper
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.os.Build
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.*
|
||||
import java.io.File
|
||||
import java.io.InputStream
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
private const val LOG_TAG = "LibWhisper"
|
||||
|
||||
class WhisperContext private constructor(private var ptr: Long) {
|
||||
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
|
||||
private val scope: CoroutineScope = CoroutineScope(
|
||||
Executors.newSingleThreadExecutor().asCoroutineDispatcher()
|
||||
)
|
||||
|
||||
suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
|
||||
require(ptr != 0L)
|
||||
WhisperLib.fullTranscribe(ptr, data)
|
||||
val textCount = WhisperLib.getTextSegmentCount(ptr)
|
||||
return@withContext buildString {
|
||||
for (i in 0 until textCount) {
|
||||
append(WhisperLib.getTextSegment(ptr, i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun release() = withContext(scope.coroutineContext) {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
runBlocking {
|
||||
release()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun createContextFromFile(filePath: String): WhisperContext {
|
||||
val ptr = WhisperLib.initContext(filePath)
|
||||
if (ptr == 0L) {
|
||||
throw java.lang.RuntimeException("Couldn't create context with path $filePath")
|
||||
}
|
||||
return WhisperContext(ptr)
|
||||
}
|
||||
|
||||
fun createContextFromInputStream(stream: InputStream): WhisperContext {
|
||||
val ptr = WhisperLib.initContextFromInputStream(stream)
|
||||
|
||||
if (ptr == 0L) {
|
||||
throw java.lang.RuntimeException("Couldn't create context from input stream")
|
||||
}
|
||||
return WhisperContext(ptr)
|
||||
}
|
||||
|
||||
fun createContextFromAsset(assetManager: AssetManager, assetPath: String): WhisperContext {
|
||||
val ptr = WhisperLib.initContextFromAsset(assetManager, assetPath)
|
||||
|
||||
if (ptr == 0L) {
|
||||
throw java.lang.RuntimeException("Couldn't create context from asset $assetPath")
|
||||
}
|
||||
return WhisperContext(ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class WhisperLib {
|
||||
companion object {
|
||||
init {
|
||||
Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}")
|
||||
var loadVfpv4 = false
|
||||
if (isArmEabiV7a()) {
|
||||
// armeabi-v7a needs runtime detection support
|
||||
val cpuInfo = cpuInfo()
|
||||
cpuInfo?.let {
|
||||
Log.d(LOG_TAG, "CPU info: $cpuInfo")
|
||||
if (cpuInfo.contains("vfpv4")) {
|
||||
Log.d(LOG_TAG, "CPU supports vfpv4")
|
||||
loadVfpv4 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (loadVfpv4) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so")
|
||||
System.loadLibrary("whisper_vfpv4")
|
||||
} else {
|
||||
Log.d(LOG_TAG, "Loading libwhisper.so")
|
||||
System.loadLibrary("whisper")
|
||||
}
|
||||
}
|
||||
|
||||
// JNI methods
|
||||
external fun initContextFromInputStream(inputStream: InputStream): Long
|
||||
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
|
||||
external fun initContext(modelPath: String): Long
|
||||
external fun freeContext(contextPtr: Long)
|
||||
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
|
||||
external fun getTextSegmentCount(contextPtr: Long): Int
|
||||
external fun getTextSegment(contextPtr: Long, index: Int): String
|
||||
}
|
||||
}
|
||||
|
||||
private fun isArmEabiV7a(): Boolean {
|
||||
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
|
||||
}
|
||||
|
||||
private fun cpuInfo(): String? {
|
||||
return try {
|
||||
File("/proc/cpuinfo").inputStream().bufferedReader().use {
|
||||
it.readText()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e)
|
||||
null
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
LOCAL_PATH := $(call my-dir)
|
||||
include $(CLEAR_VARS)
|
||||
LOCAL_MODULE := libwhisper
|
||||
include $(LOCAL_PATH)/Whisper.mk
|
||||
include $(BUILD_SHARED_LIBRARY)
|
||||
|
||||
ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
|
||||
include $(CLEAR_VARS)
|
||||
LOCAL_MODULE := libwhisper_vfpv4
|
||||
include $(LOCAL_PATH)/Whisper.mk
|
||||
# Allow building NEON FMA code.
|
||||
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
|
||||
LOCAL_CFLAGS += -mfpu=neon-vfpv4
|
||||
include $(BUILD_SHARED_LIBRARY)
|
||||
endif
|
@ -1 +0,0 @@
|
||||
APP_STL := c++_static
|
@ -1,18 +0,0 @@
|
||||
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
|
||||
LOCAL_LDLIBS := -landroid -llog
|
||||
|
||||
# Make the final output library smaller by only keeping the symbols referenced from the app.
|
||||
ifneq ($(APP_OPTIM),debug)
|
||||
LOCAL_CFLAGS += -O3
|
||||
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
|
||||
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
|
||||
LOCAL_LDFLAGS += -Wl,--gc-sections
|
||||
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
|
||||
LOCAL_LDFLAGS += -flto
|
||||
endif
|
||||
|
||||
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
|
||||
LOCAL_CPPFLAGS += -std=c++11
|
||||
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
|
||||
$(WHISPER_LIB_DIR)/whisper.cpp \
|
||||
$(LOCAL_PATH)/jni.c
|
@ -1,216 +0,0 @@
|
||||
#include <jni.h>
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <string.h>
|
||||
#include "whisper.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||
|
||||
static inline int min(int a, int b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
static inline int max(int a, int b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
struct input_stream_context {
|
||||
size_t offset;
|
||||
JNIEnv * env;
|
||||
jobject thiz;
|
||||
jobject input_stream;
|
||||
|
||||
jmethodID mid_available;
|
||||
jmethodID mid_read;
|
||||
};
|
||||
|
||||
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
|
||||
|
||||
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
|
||||
|
||||
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
|
||||
|
||||
if (size_to_copy != read_size || size_to_copy != n_read) {
|
||||
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
|
||||
}
|
||||
|
||||
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
|
||||
memcpy(output, byte_array_elements, size_to_copy);
|
||||
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
|
||||
|
||||
(*is->env)->DeleteLocalRef(is->env, byte_array);
|
||||
|
||||
is->offset += size_to_copy;
|
||||
|
||||
return size_to_copy;
|
||||
}
|
||||
bool inputStreamEof(void * ctx) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
return result <= 0;
|
||||
}
|
||||
void inputStreamClose(void * ctx) {
|
||||
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
|
||||
JNIEnv *env, jobject thiz, jobject input_stream) {
|
||||
UNUSED(thiz);
|
||||
|
||||
struct whisper_context *context = NULL;
|
||||
struct whisper_model_loader loader = {};
|
||||
struct input_stream_context inp_ctx = {};
|
||||
|
||||
inp_ctx.offset = 0;
|
||||
inp_ctx.env = env;
|
||||
inp_ctx.thiz = thiz;
|
||||
inp_ctx.input_stream = input_stream;
|
||||
|
||||
jclass cls = (*env)->GetObjectClass(env, input_stream);
|
||||
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
|
||||
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
|
||||
|
||||
loader.context = &inp_ctx;
|
||||
loader.read = inputStreamRead;
|
||||
loader.eof = inputStreamEof;
|
||||
loader.close = inputStreamClose;
|
||||
|
||||
loader.eof(loader.context);
|
||||
|
||||
context = whisper_init(&loader);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
static size_t asset_read(void *ctx, void *output, size_t read_size) {
|
||||
return AAsset_read((AAsset *) ctx, output, read_size);
|
||||
}
|
||||
|
||||
static bool asset_is_eof(void *ctx) {
|
||||
return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
|
||||
}
|
||||
|
||||
static void asset_close(void *ctx) {
|
||||
AAsset_close((AAsset *) ctx);
|
||||
}
|
||||
|
||||
static struct whisper_context *whisper_init_from_asset(
|
||||
JNIEnv *env,
|
||||
jobject assetManager,
|
||||
const char *asset_path
|
||||
) {
|
||||
LOGI("Loading model from asset '%s'\n", asset_path);
|
||||
AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
|
||||
AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING);
|
||||
if (!asset) {
|
||||
LOGW("Failed to open '%s'\n", asset_path);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_model_loader loader = {
|
||||
.context = asset,
|
||||
.read = &asset_read,
|
||||
.eof = &asset_is_eof,
|
||||
.close = &asset_close
|
||||
};
|
||||
|
||||
return whisper_init(&loader);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromAsset(
|
||||
JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL);
|
||||
context = whisper_init_from_asset(env, assetManager, asset_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
|
||||
JNIEnv *env, jobject thiz, jstring model_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
|
||||
context = whisper_init_from_file(model_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
whisper_free(context);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// Leave 2 processors free (i.e. the high-efficiency cores).
|
||||
int max_threads = max(1, min(8, get_nprocs() - 2));
|
||||
LOGI("Selecting %d threads", max_threads);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = max_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
||||
whisper_reset_timings(context);
|
||||
|
||||
LOGI("About to run whisper_full");
|
||||
if (whisper_full(context, params, audio_data_arr, audio_data_length) != 0) {
|
||||
LOGI("Failed to run the model");
|
||||
} else {
|
||||
whisper_print_timings(context);
|
||||
}
|
||||
(*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegmentCount(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
return whisper_full_n_segments(context);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
@ -1,170 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user