Compare commits

..

41 Commits

Author SHA1 Message Date
e2aa556a99 whisper : experiments with Flash Attention in the decoder 2023-01-07 21:00:51 +02:00
f30b5d322c ggml : fix bug in new soft max computation 2023-01-07 21:00:07 +02:00
44efbf7ff1 cmake : add -Wno-unused-function + update whisper.js 2023-01-07 20:18:34 +02:00
d347a59a5f ggml : when using BLAS start only 1 CPU thread 2023-01-07 19:48:56 +02:00
6394c906af ggml : fix running tasks with variable number of threads 2023-01-07 19:20:18 +02:00
74ffa14e1d ggml : unroll ggml_vec_dot_f16 in ggml_compute_forward_flash_attn_f16 2023-01-07 19:19:40 +02:00
65fdcbbbbb whisper : revert accidental MB change 2023-01-07 16:18:21 +02:00
d61d55cd4b ggml : speed-up soft max via Accelerate + unroll 2023-01-07 16:16:42 +02:00
d51fc3ee0a ggml : use vDSP_sve and vDSP_maxv from Accelerate 2023-01-07 16:10:16 +02:00
f82a7dd019 ggml : make gcc happy (minor) 2023-01-07 09:34:39 +02:00
87dd4a3081 talk.wasm : bump memory usage + update whisper.js 2023-01-06 21:13:44 +02:00
41e05c6b1b cmake : support AVX2 in Windows better (#381) 2023-01-06 19:36:33 +02:00
fa379cb22a Revert "tmp"
This reverts commit 1652965529.
2023-01-06 19:33:09 +02:00
322f4e6c4e go : bindings updated so they can be used in third party packages. (#379)
* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin
2023-01-06 19:32:28 +02:00
1652965529 tmp 2023-01-06 19:32:12 +02:00
6042c7a3be cmake : change min required version to 3.0 (#351)
We increase the min version only when want to use particular
functionality that is available in the newer version
2023-01-06 19:25:28 +02:00
6b351bb669 command : add "guided-mode" video demo in the README.md 2023-01-06 18:59:26 +02:00
a62170c656 ggml : add SSE3 and fp16 conversion lookup table (#368)
* Improves WASM performance:
  On MacBook M1 Pro, I observe 25% faster using Firefox and 35% faster using Chrome

* Add support for SSE3 SIMD

* Add SSE3 to system information

* Add Imath support for fp16-fp32 conversions

* Add Imath to system information

* Wrap Imath calls to avoid static function warnings

* Drop Imath; Add lookup table for f16 -> f32 conversions

* Remove TODO comments

* Update SSE3 to new macro arguments

* Correct updated macro definitions

* Prefer static inline where possible

* ggml : static inlines + add public f16 <-> f32 conversions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-01-06 18:45:59 +02:00
1944e7c33e whisper : document POWER VSX support 2023-01-05 23:53:00 +02:00
49a8dd6732 ggml : reorganize POWER9 ppc64le SIMD code 2023-01-05 23:53:00 +02:00
8c7f642286 ggml : change f16 load and store macro arguments 2023-01-05 23:53:00 +02:00
ad2a4ffa03 whisper : do not use F16 tensors when in F32 mode (#369) 2023-01-05 22:56:25 +02:00
b3c865083e ci : add emscripten build 2023-01-05 22:10:20 +02:00
a0d4f8e65c main : make whisper_print_segment_callback() more readable (close #371) 2023-01-05 21:45:05 +02:00
4a214d2f07 cmake : add CMAKE_RUNTIME_OUTPUT_DIRECTORY
Currently needed by the wasm examples
2023-01-05 21:40:59 +02:00
0a0cfa7985 ggml : add void to argument-less functions 2023-01-05 21:40:38 +02:00
196d738974 minor : close #370 + Makefile build info print change 2023-01-05 21:35:45 +02:00
84c6b42e65 cmake : update to 3.19 (#351)
- update from 3.0 (from 2014) to 3.19 (from 2020)
- move some global setting onto the targets (through a cmake include)
2023-01-05 21:22:48 +02:00
dd6d582977 whisper : use ranged-based for loops for readability 2023-01-05 21:20:44 +02:00
d51c5eb906 ggml : define MIN / MAX only if not defined (minor) 2023-01-05 21:16:52 +02:00
0be6a1afd9 make : print build information 2023-01-02 13:35:26 +02:00
a466c3404d stream : fix data race on bool + avoid division-by-zero 2023-01-02 10:20:50 +02:00
d629c034a4 models : fix HF model URL (close #356) 2023-01-02 09:54:43 +02:00
f00509d57c command : refactor to split command list & general transcription modes (#331)
This makes it easier to understand if you're looking for only one of the capabilities.
2022-12-31 14:08:57 +02:00
424c410c42 ggml : improve f16 acceleration for POWER9 ppc64le 2022-12-31 10:02:19 +02:00
d97e6005e9 whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
closes #344
2022-12-31 09:57:19 +02:00
3467230a77 models : fix typo in convert-h5-to-ggml.py
signficant -> significant
2022-12-31 09:49:01 +02:00
a091581eb3 cmake : add runtime destination install (#345)
needed for mingw32 build to successfully install the dlls in the correct location
2022-12-31 09:48:00 +02:00
68daf6e487 whisper : avoid some memory allocations 2022-12-30 13:43:48 +02:00
a593b932e4 main : add -ocsv, aka --output-csv to output a CSV file
Adds -ocsv, aka --output-csv feature to examples/main, which outputs a CSV file containing lines formatted as follows <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, "<transcript-line-including-commas>".
2022-12-29 14:04:00 +02:00
9a8ad3db69 make : add i686 arch (close #329) 2022-12-29 13:58:55 +02:00
36 changed files with 1420 additions and 1070 deletions

View File

@ -235,3 +235,33 @@ jobs:
with:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make

View File

@ -2,14 +2,15 @@ cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.0.4)
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
include(cmake/GitVars.cmake)
include(cmake/BuildTypes.cmake)
include(GitVars)
include(BuildTypes)
# configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
@ -52,6 +53,7 @@ if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
@ -82,9 +84,6 @@ endif()
# dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework
@ -131,6 +130,7 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \
-Wstrict-prototypes \
-Wpointer-arith \
-Wno-unused-function \
")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \
@ -157,6 +157,7 @@ else()
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -168,7 +169,10 @@ else()
if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
if(NOT WHISPER_NO_FMA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
endif()
endif()
@ -190,6 +194,8 @@ add_library(${TARGET}
whisper.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
@ -223,6 +229,7 @@ target_compile_definitions(${TARGET} PUBLIC
install(TARGETS ${TARGET}
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
)
#

View File

@ -10,6 +10,9 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin)
@ -53,10 +56,13 @@ endif
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),x86_64)
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_S),Darwin)
CFLAGS += -mfma -mf16c
CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
@ -81,6 +87,10 @@ ifeq ($(UNAME_M),x86_64)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
@ -141,6 +151,21 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
endif
#
# Print build information
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
#

View File

@ -11,6 +11,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- Low memory usage (Flash Attention + Flash Forward)
- Zero memory allocations at runtime

View File

@ -1,3 +1,2 @@
build
models
go.sum

View File

@ -1,28 +1,27 @@
CMAKE := $(shell which cmake)
BUILD_DIR := "build"
MODELS_DIR := "models"
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
C_INCLUDE_PATH := "../.."
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
all: clean whisper examples
whisper: mkdir
@echo Build whisper
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
@${CMAKE} --build ${BUILD_DIR} --target whisper
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
@go test -v .
@go test -v ./pkg/whisper/...
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models small.en
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir:
@echo Mkdir ${BUILD_DIR}

View File

@ -74,4 +74,27 @@ And you can then test a model against samples with the following command:
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -17,15 +17,14 @@ import (
// CONSTANTS
const (
srcUrl = "https://huggingface.co/" // The location of the models
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
)
var (
@ -123,11 +122,14 @@ func GetModels() []string {
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
} else {
url.Path = srcPathPrefix + "-" + model + srcExt
url.Path = filepath.Join(url.Path, model)
}
return url.String(), nil
}

23
bindings/go/go.sum Normal file
View File

@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,8 +1,5 @@
package whisper
// This file defines the whisper_token, whisper_token_data and whisper_full_params
// structures, which are used by the whisper_full() function.
import (
"fmt"
)

View File

@ -9,8 +9,7 @@ import (
// CGO
/*
#cgo CFLAGS: -I${SRCDIR}/../..
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
#cgo LDFLAGS: -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h>
#include <stdlib.h>
@ -171,6 +170,10 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
}
// Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
@ -211,6 +214,10 @@ func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
}

View File

@ -50,7 +50,10 @@ func Test_Whisper_001(t *testing.T) {
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens
num_segments := ctx.Whisper_full_n_segments()

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,17 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -1,3 +1,6 @@
set(TARGET bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -9,7 +9,19 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
@ -17,9 +29,8 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Building

View File

@ -510,6 +510,333 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
return allowed_commands;
}
// command-list mode
// guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
int max_len = 0;
std::vector<std::vector<whisper_token>> allowed_tokens;
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < (int) cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
std::vector<whisper_token> k_tokens;
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
bool is_running = true;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
return 0;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
const auto * probs = whisper_get_probs(ctx);
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
}
fprintf(stdout, "\n");
}
}
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
const float prob = probs_id[0].first;
const int index = probs_id[0].second;
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
audio.clear();
}
}
return 0;
}
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands.";
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
return 0;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0;
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
}
return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
@ -561,300 +888,12 @@ int main(int argc, char ** argv) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
int max_len = 0;
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
std::vector<std::string> allowed_commands;
std::vector<std::vector<whisper_token>> allowed_tokens;
std::string k_prompt;
std::vector<whisper_token> k_tokens;
int ret_val = 0;
if (!params.commands.empty()) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < (int) cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
ret_val = process_command_list(ctx, audio, params);
} else {
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
k_prompt = "Ok Whisper, start listening for commands.";
}
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (allowed_commands.empty()) {
// general-purpose mode
// freely transcribe the voice into text
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
int64_t t_ms = 0;
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
} else {
// command-list mode
// guide the transcription to match the most likely command from a provided list
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
const auto * probs = whisper_get_probs(ctx);
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
}
fprintf(stdout, "\n");
}
}
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
audio.clear();
}
}
ret_val = process_general_transcription(ctx, audio, params);
}
audio.pause();
@ -862,5 +901,5 @@ int main(int argc, char ** argv) {
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
return ret_val;
}

View File

@ -1,3 +1,6 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -69,6 +69,7 @@ struct whisper_params {
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
@ -111,6 +112,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -150,6 +152,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
@ -173,90 +176,81 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker;
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
@ -325,6 +319,32 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ')
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", "
<< 10 * t1 << ", \""
<< text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
@ -528,7 +548,7 @@ int main(int argc, char ** argv) {
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
@ -674,6 +694,13 @@ int main(int argc, char ** argv) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_inp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -8,6 +8,7 @@
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert>
#include <cstdio>
#include <string>
@ -144,8 +145,8 @@ private:
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
@ -155,6 +156,8 @@ private:
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
@ -427,10 +430,10 @@ int main(int argc, char ** argv) {
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
params.no_timestamps = !use_vad;
params.no_context = use_vad;
params.max_tokens = 0;

View File

@ -9,6 +9,8 @@ add_executable(${TARGET}
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
@ -31,8 +33,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1600MB \
-s TOTAL_MEMORY=1600MB \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.6GB of RAM
- The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.

View File

@ -8,6 +8,9 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

1224
ggml.c

File diff suppressed because it is too large Load Diff

14
ggml.h
View File

@ -177,13 +177,11 @@ extern "C" {
#include <stddef.h>
#include <stdbool.h>
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
#define GGML_MAX_THREADS 64
#define GGML_MAX_THREAD_POOLS 16
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
#ifdef __ARM_NEON
// we use the built-in 16-bit float type
@ -733,6 +731,8 @@ int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
#ifdef __cplusplus
}

View File

@ -56,7 +56,7 @@ def bytes_to_unicode():
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""

View File

@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%

View File

@ -204,6 +204,10 @@ struct whisper_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
// used to avoid memory allocations during sampling
// TODO: move to whisper_context in the future
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
id token_eot = 50256;
id token_sot = 50257;
id token_prev = 50360;
@ -408,6 +412,8 @@ struct whisper_context {
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer;
ggml_type wtype; // weight type (FP32 or FP16)
whisper_model model;
whisper_vocab vocab;
@ -431,9 +437,8 @@ struct whisper_context {
};
template<typename T>
static void read_safe(std::ifstream& fin, T& dest)
{
fin.read((char*)& dest, sizeof(T));
static void read_safe(std::ifstream& fin, T& dest) {
fin.read((char*)& dest, sizeof(T));
}
// load the model from a ggml file
@ -551,6 +556,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std::string word;
std::vector<char> tmp;
tmp.reserve(128);
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
read_safe(fin, len);
@ -603,6 +611,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
vocab.id_to_token[i] = word;
}
}
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
vocab.probs_id.reserve(n_vocab);
}
{
@ -618,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
const ggml_type wtype = wctx.wtype;
size_t ctx_size = 0;
@ -639,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// encoder
{
// TODO: F16 .. maybe not?
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
@ -654,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// decoder
{
// TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@ -971,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
// key/value memory for the cross-attention layer
@ -982,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
const size_t memory_size =
@ -1021,7 +1034,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std::string name;
std::vector<char> tmp(length); // create a buffer
fin.read( &tmp[0], tmp.size() ); // read to buffer
fin.read(&tmp[0], tmp.size()); // read to buffer
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) {
@ -1229,14 +1242,14 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
@ -1246,7 +1259,7 @@ static bool whisper_encode(
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1262,7 +1275,7 @@ static bool whisper_encode(
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
@ -1280,7 +1293,7 @@ static bool whisper_encode(
// ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1292,7 +1305,7 @@ static bool whisper_encode(
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1337,7 +1350,7 @@ static bool whisper_encode(
#ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL,
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
// fully connected
@ -1444,7 +1457,7 @@ static bool whisper_encode(
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
//Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
@ -1566,14 +1579,14 @@ static bool whisper_decode(
Qcur),
Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
layer.attn_k_w,
cur);
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
layer.attn_v_w,
@ -1596,6 +1609,33 @@ static bool whisper_decode(
// ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
#else
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
@ -1613,13 +1653,13 @@ static bool whisper_decode(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctxL,
// KQ,
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
// );
struct ggml_tensor * KQ_scaled =
ggml_scale(ctxL,
KQ,
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
@ -1631,6 +1671,7 @@ static bool whisper_decode(
1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1676,7 +1717,7 @@ static bool whisper_decode(
Qcur),
Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// Kcross is already scaled
struct ggml_tensor * Kcross =
@ -1691,6 +1732,24 @@ static bool whisper_decode(
// ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
@ -1717,6 +1776,7 @@ static bool whisper_decode(
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1849,7 +1909,7 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
whisper_vocab & vocab,
const float * probs,
bool force_timestamp,
bool is_initial) {
@ -1857,11 +1917,11 @@ static whisper_token_data whisper_sample_best(
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab.id_to_token.size();
const int n_logits = vocab.n_vocab;
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
auto & probs_id = vocab.probs_id;
probs_id.clear();
for (int i = 0; i < n_logits; i++) {
probs_id.emplace_back(probs[i], i);
}
@ -2001,6 +2061,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
std::vector<float> even;
std::vector<float> odd;
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
@ -2434,7 +2497,7 @@ int whisper_lang_auto_detect(
std::vector<std::pair<float, int>> probs_id;
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
}
// sort descending
@ -2458,12 +2521,12 @@ int whisper_lang_auto_detect(
}
{
for (int i = 0; i < (int) probs_id.size(); i++) {
for (const auto & prob : probs_id) {
if (lang_probs) {
lang_probs[probs_id[i].second] = probs_id[i].first;
lang_probs[prob.second] = prob.first;
}
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
//printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
}
}
@ -2482,6 +2545,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_text_ctx;
}
int whisper_n_audio_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_audio_ctx;
}
int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0;
}
@ -2562,6 +2629,8 @@ const char * whisper_print_system_info(void) {
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str();
}
@ -2807,7 +2876,11 @@ int whisper_full(
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -4;
}
ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
@ -3134,7 +3207,7 @@ int whisper_full_parallel(
// separate key + value memory for each processor
{
auto & ctx = model.ctx_mem;
auto & mctx = model.ctx_mem;
const auto & hparams = model.hparams;
@ -3147,8 +3220,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
// key/value memory for the cross-attention layer
@ -3158,8 +3231,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
}
}
@ -3203,17 +3276,17 @@ int whisper_full_parallel(
for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = ctxs[i].result_all;
for (int j = 0; j < (int) results_i.size(); ++j) {
for (auto & result : results_i) {
// correct the segment timestamp taking into account the offset
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (!ctx->result_all.empty()) {
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
}
ctx->result_all.push_back(std::move(results_i[j]));
ctx->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
@ -3308,18 +3381,18 @@ static int64_t sample_to_timestamp(int i_sample) {
static float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
for (char c : text) {
if (c == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
} else if (c == ',') {
res += 2.00f;
} else if (text[i] == '.') {
} else if (c == '.') {
res += 3.00f;
} else if (text[i] == '!') {
} else if (c == '!') {
res += 3.00f;
} else if (text[i] == '?') {
} else if (c == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
} else if (c >= '0' && c <= '9') {
res += 3.00f;
} else {
res += 1.00f;

View File

@ -148,7 +148,7 @@ extern "C" {
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
int n_max_tokens);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
@ -177,6 +177,7 @@ extern "C" {
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token