mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-25 09:31:44 +00:00
Compare commits
41 Commits
threads
...
fa-decoder
Author | SHA1 | Date | |
---|---|---|---|
e2aa556a99 | |||
f30b5d322c | |||
44efbf7ff1 | |||
d347a59a5f | |||
6394c906af | |||
74ffa14e1d | |||
65fdcbbbbb | |||
d61d55cd4b | |||
d51fc3ee0a | |||
f82a7dd019 | |||
87dd4a3081 | |||
41e05c6b1b | |||
fa379cb22a | |||
322f4e6c4e | |||
1652965529 | |||
6042c7a3be | |||
6b351bb669 | |||
a62170c656 | |||
1944e7c33e | |||
49a8dd6732 | |||
8c7f642286 | |||
ad2a4ffa03 | |||
b3c865083e | |||
a0d4f8e65c | |||
4a214d2f07 | |||
0a0cfa7985 | |||
196d738974 | |||
84c6b42e65 | |||
dd6d582977 | |||
d51c5eb906 | |||
0be6a1afd9 | |||
a466c3404d | |||
d629c034a4 | |||
f00509d57c | |||
424c410c42 | |||
d97e6005e9 | |||
3467230a77 | |||
a091581eb3 | |||
68daf6e487 | |||
a593b932e4 | |||
9a8ad3db69 |
30
.github/workflows/build.yml
vendored
30
.github/workflows/build.yml
vendored
@ -235,3 +235,33 @@ jobs:
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
@ -2,14 +2,15 @@ cmake_minimum_required (VERSION 3.0)
|
||||
|
||||
project(whisper.cpp VERSION 1.0.4)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
include(cmake/GitVars.cmake)
|
||||
include(cmake/BuildTypes.cmake)
|
||||
include(GitVars)
|
||||
include(BuildTypes)
|
||||
|
||||
# configure project version
|
||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||
@ -52,6 +53,7 @@ if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
endif()
|
||||
@ -82,9 +84,6 @@ endif()
|
||||
|
||||
# dependencies
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# on APPLE - include Accelerate framework
|
||||
@ -131,6 +130,7 @@ if (WHISPER_ALL_WARNINGS)
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
@ -157,6 +157,7 @@ else()
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
@ -168,7 +169,10 @@ else()
|
||||
if(NOT WHISPER_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@ -190,6 +194,8 @@ add_library(${TARGET}
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
@ -223,6 +229,7 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
|
||||
#
|
||||
|
29
Makefile
29
Makefile
@ -10,6 +10,9 @@ ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -53,10 +56,13 @@ endif
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mfma -mf16c
|
||||
CFLAGS += -mf16c
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
@ -81,6 +87,10 @@ ifeq ($(UNAME_M),x86_64)
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
@ -141,6 +151,21 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
#
|
||||
# Print build information
|
||||
#
|
||||
|
||||
$(info I whisper.cpp build info: )
|
||||
$(info I UNAME_S: $(UNAME_S))
|
||||
$(info I UNAME_P: $(UNAME_P))
|
||||
$(info I UNAME_M: $(UNAME_M))
|
||||
$(info I CFLAGS: $(CFLAGS))
|
||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
||||
$(info I LDFLAGS: $(LDFLAGS))
|
||||
$(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
default: main
|
||||
|
||||
#
|
||||
|
@ -11,6 +11,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- Zero memory allocations at runtime
|
||||
|
1
bindings/go/.gitignore
vendored
1
bindings/go/.gitignore
vendored
@ -1,3 +1,2 @@
|
||||
build
|
||||
models
|
||||
go.sum
|
||||
|
@ -1,28 +1,27 @@
|
||||
CMAKE := $(shell which cmake)
|
||||
BUILD_DIR := "build"
|
||||
MODELS_DIR := "models"
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
C_INCLUDE_PATH := "../.."
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
||||
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
@go test -v .
|
||||
@go test -v ./pkg/whisper/...
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models small.en
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
|
@ -74,4 +74,27 @@ And you can then test a model against samples with the following command:
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
## Using the bindings
|
||||
|
||||
To use the bindings in your own software,
|
||||
|
||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
||||
|
||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
||||
|
||||
The API Documentation:
|
||||
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
||||
|
||||
Getting help:
|
||||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
||||
|
@ -17,15 +17,14 @@ import (
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/" // The location of the models
|
||||
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
@ -123,11 +122,14 @@ func GetModels() []string {
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
if filepath.Ext(model) != srcExt {
|
||||
model += srcExt
|
||||
}
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = srcPathPrefix + "-" + model + srcExt
|
||||
url.Path = filepath.Join(url.Path, model)
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
23
bindings/go/go.sum
Normal file
23
bindings/go/go.sum
Normal file
@ -0,0 +1,23 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
@ -1,8 +1,5 @@
|
||||
package whisper
|
||||
|
||||
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
||||
// structures, which are used by the whisper_full() function.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
@ -9,8 +9,7 @@ import (
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/../..
|
||||
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
@ -171,6 +170,10 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
//
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
@ -211,6 +214,10 @@ func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
@ -50,7 +50,10 @@ func Test_Whisper_001(t *testing.T) {
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
err = ctx.Whisper_full(params, data, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
|
Submodule bindings/ios updated: 1502317fe0...6707f1ea1c
File diff suppressed because one or more lines are too long
17
cmake/DefaultTargetOptions.cmake
Normal file
17
cmake/DefaultTargetOptions.cmake
Normal file
@ -0,0 +1,17 @@
|
||||
# Set the default compile features and properties for a target.
|
||||
|
||||
if (NOT TARGET)
|
||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET}
|
||||
PRIVATE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -1,3 +1,6 @@
|
||||
set(TARGET bench)
|
||||
add_executable(${TARGET} bench.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -9,7 +9,19 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Guided mode
|
||||
|
||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
||||
|
||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
||||
|
||||
```bash
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
@ -17,9 +29,8 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Building
|
||||
|
||||
|
@ -510,6 +510,333 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
return allowed_commands;
|
||||
}
|
||||
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
int max_len = 0;
|
||||
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
std::vector<whisper_token> k_tokens;
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
bool is_running = true;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto * probs = whisper_get_probs(ctx);
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int token : allowed_tokens[cmd.second]) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// best command
|
||||
{
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const float prob = probs_id[0].first;
|
||||
const int index = probs_id[0].second;
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -561,300 +888,12 @@ int main(int argc, char ** argv) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
audio.clear();
|
||||
|
||||
int max_len = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
std::vector<std::string> allowed_commands;
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
std::string k_prompt;
|
||||
std::vector<whisper_token> k_tokens;
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
k_prompt = "Ok Whisper, start listening for commands.";
|
||||
}
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
int64_t t_ms = 0;
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto * probs = whisper_get_probs(ctx);
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// best command
|
||||
{
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
@ -862,5 +901,5 @@ int main(int argc, char ** argv) {
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
return ret_val;
|
||||
}
|
||||
|
@ -1,3 +1,6 @@
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -69,6 +69,7 @@ struct whisper_params {
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
@ -111,6 +112,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
@ -150,6 +152,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
@ -173,90 +176,81 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
}
|
||||
fflush(stdout);
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
std::string speaker;
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
|
||||
}
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
@ -325,6 +319,32 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
if (text[0] == ' ')
|
||||
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << ", "
|
||||
<< 10 * t1 << ", \""
|
||||
<< text << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
@ -528,7 +548,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return 8;
|
||||
}
|
||||
|
||||
@ -674,6 +694,13 @@ int main(int argc, char ** argv) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_inp + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
@ -144,8 +145,8 @@ private:
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
@ -155,6 +156,8 @@ private:
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
@ -427,10 +430,10 @@ int main(int argc, char ** argv) {
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context = use_vad;
|
||||
params.max_tokens = 0;
|
||||
|
@ -9,6 +9,8 @@ add_executable(${TARGET}
|
||||
gpt-2.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
@ -31,8 +33,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1600MB \
|
||||
-s TOTAL_MEMORY=1600MB \
|
||||
-s INITIAL_MEMORY=1800MB \
|
||||
-s TOTAL_MEMORY=1800MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
|
||||
- Latest Chrome or Firefox browser (Safari is not supported)
|
||||
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
|
||||
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
|
||||
- The web-page uses about 1.6GB of RAM
|
||||
- The web-page uses about 1.8GB of RAM
|
||||
|
||||
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
|
||||
Also, the prompting strategy can likely be improved to achieve better results.
|
||||
|
@ -8,6 +8,9 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -8,6 +8,8 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
14
ggml.h
14
ggml.h
@ -177,13 +177,11 @@ extern "C" {
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_NODES 4096
|
||||
#define GGML_MAX_PARAMS 16
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_OPT 4
|
||||
#define GGML_MAX_THREADS 64
|
||||
#define GGML_MAX_THREAD_POOLS 16
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_NODES 4096
|
||||
#define GGML_MAX_PARAMS 16
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_OPT 4
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
// we use the built-in 16-bit float type
|
||||
@ -733,6 +731,8 @@ int ggml_cpu_has_f16c(void);
|
||||
int ggml_cpu_has_fp16_va(void);
|
||||
int ggml_cpu_has_wasm_simd(void);
|
||||
int ggml_cpu_has_blas(void);
|
||||
int ggml_cpu_has_sse3(void);
|
||||
int ggml_cpu_has_vsx(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ def bytes_to_unicode():
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
|
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
|
||||
goto :eof
|
||||
)
|
||||
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo Failed to download ggml model %model%
|
||||
|
181
whisper.cpp
181
whisper.cpp
@ -204,6 +204,10 @@ struct whisper_vocab {
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
|
||||
// used to avoid memory allocations during sampling
|
||||
// TODO: move to whisper_context in the future
|
||||
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
||||
|
||||
id token_eot = 50256;
|
||||
id token_sot = 50257;
|
||||
id token_prev = 50360;
|
||||
@ -408,6 +412,8 @@ struct whisper_context {
|
||||
std::vector<uint8_t> buf_compute;
|
||||
std::vector<uint8_t> buf_compute_layer;
|
||||
|
||||
ggml_type wtype; // weight type (FP32 or FP16)
|
||||
|
||||
whisper_model model;
|
||||
whisper_vocab vocab;
|
||||
|
||||
@ -431,9 +437,8 @@ struct whisper_context {
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static void read_safe(std::ifstream& fin, T& dest)
|
||||
{
|
||||
fin.read((char*)& dest, sizeof(T));
|
||||
static void read_safe(std::ifstream& fin, T& dest) {
|
||||
fin.read((char*)& dest, sizeof(T));
|
||||
}
|
||||
|
||||
// load the model from a ggml file
|
||||
@ -551,6 +556,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
std::string word;
|
||||
std::vector<char> tmp;
|
||||
|
||||
tmp.reserve(128);
|
||||
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
read_safe(fin, len);
|
||||
@ -603,6 +611,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
||||
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
||||
|
||||
vocab.probs_id.reserve(n_vocab);
|
||||
}
|
||||
|
||||
{
|
||||
@ -618,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// in order to save memory and also to speed up the computation
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
const ggml_type wtype = wctx.wtype;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
@ -639,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// encoder
|
||||
{
|
||||
// TODO: F16 .. maybe not?
|
||||
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
|
||||
|
||||
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
|
||||
@ -654,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// decoder
|
||||
{
|
||||
// TODO: F16 .. maybe not?
|
||||
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
|
||||
|
||||
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
|
||||
@ -971,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
const int n_mem = n_text_layer*n_text_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
}
|
||||
|
||||
// key/value memory for the cross-attention layer
|
||||
@ -982,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
const int n_mem = n_text_layer*n_audio_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
}
|
||||
|
||||
const size_t memory_size =
|
||||
@ -1021,7 +1034,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
std::string name;
|
||||
std::vector<char> tmp(length); // create a buffer
|
||||
fin.read( &tmp[0], tmp.size() ); // read to buffer
|
||||
fin.read(&tmp[0], tmp.size()); // read to buffer
|
||||
name.assign(&tmp[0], tmp.size());
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
@ -1229,14 +1242,14 @@ static bool whisper_encode(
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Kcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
@ -1246,7 +1259,7 @@ static bool whisper_encode(
|
||||
Vcur,
|
||||
n_state/n_head, n_head, n_ctx),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
||||
@ -1262,7 +1275,7 @@ static bool whisper_encode(
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Kcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
@ -1280,7 +1293,7 @@ static bool whisper_encode(
|
||||
// ggml_permute(ctxL,
|
||||
// ggml_cpy(ctxL,
|
||||
// Vcur,
|
||||
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
// ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
// 1, 2, 0, 3);
|
||||
|
||||
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
@ -1292,7 +1305,7 @@ static bool whisper_encode(
|
||||
Vcur,
|
||||
n_state/n_head, n_head, n_ctx),
|
||||
0, 2, 1, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
|
||||
@ -1337,7 +1350,7 @@ static bool whisper_encode(
|
||||
|
||||
#ifdef USE_FLASH_FF
|
||||
cur = ggml_flash_ff(ctxL,
|
||||
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
|
||||
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
|
||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||
#else
|
||||
// fully connected
|
||||
@ -1444,7 +1457,7 @@ static bool whisper_encode(
|
||||
layer.cross_attn_k_w,
|
||||
cur);
|
||||
|
||||
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
//Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_v_w,
|
||||
@ -1566,14 +1579,14 @@ static bool whisper_decode(
|
||||
Qcur),
|
||||
Qcur);
|
||||
|
||||
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
// note: no bias for Key
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
||||
layer.attn_k_w,
|
||||
cur);
|
||||
|
||||
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
|
||||
layer.attn_v_w,
|
||||
@ -1596,6 +1609,33 @@ static bool whisper_decode(
|
||||
|
||||
// ------
|
||||
|
||||
#ifdef USE_FLASH_ATTN
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctxL,
|
||||
ggml_reshape_3d(ctxL,
|
||||
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
|
||||
n_state/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_cpy(ctxL,
|
||||
ggml_permute(ctxL,
|
||||
ggml_reshape_3d(ctxL,
|
||||
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
|
||||
n_state/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
|
||||
#else
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
@ -1613,13 +1653,13 @@ static bool whisper_decode(
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
|
||||
|
||||
//struct ggml_tensor * KQ_scaled =
|
||||
// ggml_scale(ctxL,
|
||||
// KQ,
|
||||
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
||||
// );
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctxL,
|
||||
KQ,
|
||||
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
|
||||
|
||||
@ -1631,6 +1671,7 @@ static bool whisper_decode(
|
||||
1, 2, 0, 3);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
#endif
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
||||
|
||||
@ -1676,7 +1717,7 @@ static bool whisper_decode(
|
||||
Qcur),
|
||||
Qcur);
|
||||
|
||||
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
// Kcross is already scaled
|
||||
struct ggml_tensor * Kcross =
|
||||
@ -1691,6 +1732,24 @@ static bool whisper_decode(
|
||||
|
||||
// ------
|
||||
|
||||
#ifdef USE_FLASH_ATTN
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_cpy(ctxL,
|
||||
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
||||
#else
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
@ -1717,6 +1776,7 @@ static bool whisper_decode(
|
||||
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
#endif
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
||||
|
||||
@ -1849,7 +1909,7 @@ static bool whisper_decode(
|
||||
|
||||
// the most basic sampling scheme - select the top token
|
||||
static whisper_token_data whisper_sample_best(
|
||||
const whisper_vocab & vocab,
|
||||
whisper_vocab & vocab,
|
||||
const float * probs,
|
||||
bool force_timestamp,
|
||||
bool is_initial) {
|
||||
@ -1857,11 +1917,11 @@ static whisper_token_data whisper_sample_best(
|
||||
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
||||
};
|
||||
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
const int n_logits = vocab.n_vocab;
|
||||
|
||||
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
||||
probs_id.reserve(n_logits);
|
||||
auto & probs_id = vocab.probs_id;
|
||||
|
||||
probs_id.clear();
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
probs_id.emplace_back(probs[i], i);
|
||||
}
|
||||
@ -2001,6 +2061,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
||||
std::vector<float> even;
|
||||
std::vector<float> odd;
|
||||
|
||||
even.reserve(N/2);
|
||||
odd.reserve(N/2);
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (i % 2 == 0) {
|
||||
even.push_back(in[i]);
|
||||
@ -2434,7 +2497,7 @@ int whisper_lang_auto_detect(
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
for (const auto & kv : g_lang) {
|
||||
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
||||
probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
|
||||
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
|
||||
}
|
||||
|
||||
// sort descending
|
||||
@ -2458,12 +2521,12 @@ int whisper_lang_auto_detect(
|
||||
}
|
||||
|
||||
{
|
||||
for (int i = 0; i < (int) probs_id.size(); i++) {
|
||||
for (const auto & prob : probs_id) {
|
||||
if (lang_probs) {
|
||||
lang_probs[probs_id[i].second] = probs_id[i].first;
|
||||
lang_probs[prob.second] = prob.first;
|
||||
}
|
||||
|
||||
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
||||
//printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2482,6 +2545,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_text_ctx;
|
||||
}
|
||||
|
||||
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_audio_ctx;
|
||||
}
|
||||
|
||||
int whisper_is_multilingual(struct whisper_context * ctx) {
|
||||
return ctx->vocab.is_multilingual() ? 1 : 0;
|
||||
}
|
||||
@ -2562,6 +2629,8 @@ const char * whisper_print_system_info(void) {
|
||||
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
||||
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
||||
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
||||
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
||||
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
@ -2807,7 +2876,11 @@ int whisper_full(
|
||||
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
||||
}
|
||||
|
||||
// overwrite audio_ctx
|
||||
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
||||
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
||||
return -4;
|
||||
}
|
||||
ctx->exp_n_audio_ctx = params.audio_ctx;
|
||||
|
||||
// these tokens determine the task that will be performed
|
||||
@ -3134,7 +3207,7 @@ int whisper_full_parallel(
|
||||
|
||||
// separate key + value memory for each processor
|
||||
{
|
||||
auto & ctx = model.ctx_mem;
|
||||
auto & mctx = model.ctx_mem;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@ -3147,8 +3220,8 @@ int whisper_full_parallel(
|
||||
const int n_mem = n_text_layer*n_text_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
}
|
||||
|
||||
// key/value memory for the cross-attention layer
|
||||
@ -3158,8 +3231,8 @@ int whisper_full_parallel(
|
||||
const int n_mem = n_text_layer*n_audio_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3203,17 +3276,17 @@ int whisper_full_parallel(
|
||||
for (int i = 0; i < n_processors - 1; ++i) {
|
||||
auto & results_i = ctxs[i].result_all;
|
||||
|
||||
for (int j = 0; j < (int) results_i.size(); ++j) {
|
||||
for (auto & result : results_i) {
|
||||
// correct the segment timestamp taking into account the offset
|
||||
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
|
||||
// make sure that segments are not overlapping
|
||||
if (!ctx->result_all.empty()) {
|
||||
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
|
||||
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
|
||||
}
|
||||
|
||||
ctx->result_all.push_back(std::move(results_i[j]));
|
||||
ctx->result_all.push_back(std::move(result));
|
||||
|
||||
// call the new_segment_callback for each segment
|
||||
if (params.new_segment_callback) {
|
||||
@ -3308,18 +3381,18 @@ static int64_t sample_to_timestamp(int i_sample) {
|
||||
static float voice_length(const std::string & text) {
|
||||
float res = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < text.size(); ++i) {
|
||||
if (text[i] == ' ') {
|
||||
for (char c : text) {
|
||||
if (c == ' ') {
|
||||
res += 0.01f;
|
||||
} else if (text[i] == ',') {
|
||||
} else if (c == ',') {
|
||||
res += 2.00f;
|
||||
} else if (text[i] == '.') {
|
||||
} else if (c == '.') {
|
||||
res += 3.00f;
|
||||
} else if (text[i] == '!') {
|
||||
} else if (c == '!') {
|
||||
res += 3.00f;
|
||||
} else if (text[i] == '?') {
|
||||
} else if (c == '?') {
|
||||
res += 3.00f;
|
||||
} else if (text[i] >= '0' && text[i] <= '9') {
|
||||
} else if (c >= '0' && c <= '9') {
|
||||
res += 3.00f;
|
||||
} else {
|
||||
res += 1.00f;
|
||||
|
@ -148,7 +148,7 @@ extern "C" {
|
||||
struct whisper_context * ctx,
|
||||
const char * text,
|
||||
whisper_token * tokens,
|
||||
int n_max_tokens);
|
||||
int n_max_tokens);
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
WHISPER_API int whisper_lang_max_id();
|
||||
@ -177,6 +177,7 @@ extern "C" {
|
||||
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
||||
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
||||
|
||||
// The probabilities for the next token
|
||||
|
Reference in New Issue
Block a user