Compare commits

..

1 Commits

Author SHA1 Message Date
4e6d2e98ab ggml : try to improve threading 2022-12-29 13:05:20 +02:00
68 changed files with 1972 additions and 3889 deletions

View File

@ -1,17 +0,0 @@
name: Bindings Tests
on:
push:
paths:
- bindings/go/**
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: '^1.19'
- uses: actions/checkout@v1
- run: |
cd bindings/go
make test

View File

@ -235,33 +235,3 @@ jobs:
with:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make

3
.gitignore vendored
View File

@ -8,7 +8,6 @@ build/
build-em/
build-debug/
build-release/
build-static/
build-sanitize-addr/
build-sanitize-thread/
@ -18,9 +17,7 @@ build-sanitize-thread/
/talk
/bench
arm_neon.h
sync.sh
libwhisper.a
libwhisper.so
compile_commands.json

View File

@ -1,16 +1,15 @@
cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.1.1)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
project(whisper.cpp VERSION 1.0.4)
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
include(GitVars)
include(BuildTypes)
include(cmake/GitVars.cmake)
include(cmake/BuildTypes.cmake)
# configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
@ -53,7 +52,6 @@ if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
@ -84,6 +82,9 @@ endif()
# dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework
@ -130,7 +131,6 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \
-Wstrict-prototypes \
-Wpointer-arith \
-Wno-unused-function \
")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \
@ -157,7 +157,6 @@ else()
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -169,10 +168,7 @@ else()
if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
if(NOT WHISPER_NO_FMA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
endif()
endif()
endif()
@ -194,8 +190,6 @@ add_library(${TARGET}
whisper.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
@ -229,7 +223,6 @@ target_compile_definitions(${TARGET} PUBLIC
install(TARGETS ${TARGET}
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
)
#

View File

@ -10,9 +10,6 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin)
@ -56,13 +53,10 @@ endif
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_M),x86_64)
ifeq ($(UNAME_S),Darwin)
CFLAGS += -mf16c
CFLAGS += -mfma -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
@ -87,10 +81,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
@ -133,8 +123,8 @@ ifdef WHISPER_OPENBLAS
LDFLAGS += -lopenblas
endif
ifdef WHISPER_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
CFLAGS += -pg
CXXFLAGS += -pg
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
endif
@ -151,21 +141,6 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
endif
#
# Print build information
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
#

View File

@ -4,14 +4,13 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.1.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.1.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
[Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- Low memory usage (Flash Attention + Flash Forward)
- Zero memory allocations at runtime
@ -71,7 +70,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
make
# transcribe an audio file
./main -f samples/jfk.wav
./main -f input.wav
```
---
@ -89,36 +88,27 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
@ -221,7 +211,17 @@ make large
## Limitations
- Inference only
- No GPU support (yet)
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Another example
@ -306,7 +306,6 @@ The [stream](examples/stream) tool samples the audio every half a second and run
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```java
make stream
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```

View File

@ -1,2 +1,3 @@
build
models
go.sum

View File

@ -1,27 +1,28 @@
BUILD_DIR := build
MODELS_DIR := models
CMAKE := $(shell which cmake)
BUILD_DIR := "build"
MODELS_DIR := "models"
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
C_INCLUDE_PATH := "../.."
all: clean whisper examples
whisper: mkdir
@echo Build whisper
@${MAKE} -C ../.. libwhisper.a
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
@${CMAKE} --build ${BUILD_DIR} --target whisper
test: model-small whisper modtidy
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
@go test -v .
@go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
@${BUILD_DIR}/go-model-download -out models small.en
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir:
@echo Mkdir ${BUILD_DIR}

View File

@ -74,27 +74,4 @@ And you can then test a model against samples with the following command:
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -17,14 +17,15 @@ import (
// CONSTANTS
const (
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
srcUrl = "https://huggingface.co/" // The location of the models
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
)
var (
@ -122,14 +123,11 @@ func GetModels() []string {
// URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl)
if err != nil {
return "", err
} else {
url.Path = filepath.Join(url.Path, model)
url.Path = srcPathPrefix + "-" + model + srcExt
}
return url.String(), nil
}

View File

@ -1,22 +0,0 @@
package main
import "fmt"
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
Reset = "\033[0m"
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
RGBSuffix = "m"
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Colorize text with RGB values, from 0 to 23
func Colorize(text string, v int) string {
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
// Grayscale colors are in the range 232-255
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
}

View File

@ -2,12 +2,6 @@ package main
import (
"flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
///////////////////////////////////////////////////////////////////////////////
@ -48,26 +42,6 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}
func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}
func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
@ -76,81 +50,12 @@ func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}
func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}
func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}
func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}
// Return success
return nil
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.String("language", "", "Language")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
}

View File

@ -35,7 +35,8 @@ func main() {
// Process files
for _, filename := range flags.Args() {
if err := Process(model, filename, flags); err != nil {
fmt.Println("Processing", filename)
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}

View File

@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav"
)
func Process(model whisper.Model, path string, flags *Flags) error {
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
var data []float32
// Create processing context
@ -20,20 +20,14 @@ func Process(model whisper.Model, path string, flags *Flags) error {
return err
}
// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}
// Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()
// Decode the WAV file - load the full buffer
// Decode the WAV file
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
@ -45,83 +39,42 @@ func Process(model whisper.Model, path string, flags *Flags) error {
data = buf.AsFloat32Buffer().Data
}
// Segment callback when -tokens is specified
// Set the parameters
var cb whisper.SegmentCallback
if flags.IsTokens() {
if lang != "" {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
cb = func(segment whisper.Segment) {
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
fmt.Printf("%q ", token.Text)
}
fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
fmt.Println("")
}
}
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil {
return err
}
// Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}
// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
break
} else if err != nil {
return err
}
fmt.Fprintln(w, n)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
}
}
// Output text to terminal
func Output(w io.Writer, context whisper.Context, colorize bool) error {
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
if colorize {
for _, token := range segment.Tokens {
if !context.IsText(token) {
continue
}
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
}
fmt.Fprint(w, "\n")
} else {
fmt.Fprintln(w, " ", segment.Text)
}
}
}
// Return srtTimestamp
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
// Return success
return nil
}

View File

@ -1,23 +0,0 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,5 +1,8 @@
package whisper
// This file defines the whisper_token, whisper_token_data and whisper_full_params
// structures, which are used by the whisper_full() function.
import (
"fmt"
)
@ -47,7 +50,6 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
// Set language id
func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang))
if str == nil {
@ -58,7 +60,6 @@ func (p *Params) SetLanguage(lang int) error {
return nil
}
// Get language id
func (p *Params) Language() int {
if p.language == nil {
return -1
@ -66,41 +67,18 @@ func (p *Params) Language() int {
return int(C.whisper_lang_id(p.language))
}
// Set number of threads to use
func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads)
}
// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms)
}
// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms)
}
// Set timestamp token probability threshold (~0.01)
func (p *Params) SetTokenThreshold(t float32) {
p.thold_pt = C.float(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (p *Params) SetTokenSumThreshold(t float32) {
p.thold_ptsum = C.float(t)
}
// Set max segment length in characters
func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}
// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

View File

@ -11,11 +11,10 @@ import (
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
)
///////////////////////////////////////////////////////////////////////////////

View File

@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func newContext(model *model, params whisper.Params) (Context, error) {
func NewContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
@ -41,9 +41,6 @@ func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
if !context.model.IsMultilingual() {
return ErrModelNotMultilingual
}
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
@ -53,60 +50,16 @@ func (context *context) SetLanguage(lang string) error {
return nil
}
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Get language
func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
}
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
}
// Set time offset
func (context *context) SetOffset(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
context.params.SetTokenThreshold(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
context.params.SetTokenSumThreshold(t)
}
// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
@ -166,65 +119,6 @@ func (context *context) NextSegment() (Segment, error) {
return result, nil
}
// Test for text tokens
func (context *context) IsText(t Token) bool {
switch {
case context.IsBEG(t):
return false
case context.IsSOT(t):
return false
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
return false
case context.IsPREV(t):
return false
case context.IsSOLM(t):
return false
case context.IsNOT(t):
return false
default:
return true
}
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}
// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}
// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}
// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}
// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}
// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}
// Test for token associated with a specific language
func (context *context) IsLANG(t Token, lang string) bool {
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
} else {
return false
}
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

View File

@ -20,9 +20,6 @@ type Model interface {
// Return a new speech-to-text context.
NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported.
Languages() []string
}
@ -30,18 +27,8 @@ type Model interface {
// Context is the speach recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
SetSpeedup(bool) // Set speedup flag
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
@ -51,15 +38,6 @@ type Context interface {
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
}
// Segment is the text result of a speech recognition.

View File

@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (Model, error) {
func New(path string) (*model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
@ -64,11 +64,6 @@ func (model *model) String() string {
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
@ -96,5 +91,5 @@ func (model *model) NewContext() (Context, error) {
params.SetThreads(runtime.NumCPU())
// Return new context
return newContext(model, params)
return NewContext(model, params)
}

View File

@ -9,7 +9,8 @@ import (
// CGO
/*
#cgo LDFLAGS: -lwhisper -lm -lstdc++
#cgo CFLAGS: -I${SRCDIR}/../..
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h>
#include <stdlib.h>
@ -91,7 +92,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
if ctx := C.whisper_init(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil
@ -147,6 +148,16 @@ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
}
}
// whisper_sample_best() returns the token with the highest probability
func (ctx *Context) Whisper_sample_best() TokenData {
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
}
// whisper_sample_timestamp() returns the most probable timestamp token
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
}
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
@ -160,10 +171,6 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
}
// Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
@ -204,10 +211,6 @@ func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
}

View File

@ -50,10 +50,7 @@ func Test_Whisper_001(t *testing.T) {
ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx)
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
// Print out tokens
num_segments := ctx.Whisper_full_n_segments()

View File

@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str());
g_context = whisper_init(path_model.c_str());
if (g_context != nullptr) {
return true;
} else {

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.1.1",
"version": "1.0.4",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

View File

@ -1,17 +0,0 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -28,11 +28,6 @@ void bench_main(size_t index) {
return;
}
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
}
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return;
@ -57,7 +52,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();

View File

@ -1,6 +1,3 @@
set(TARGET bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -7,7 +7,6 @@
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
std::string model = "models/ggml-base.en.bin";
};
@ -24,7 +23,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -43,17 +41,19 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
fprintf(stderr, "\n");
}
int whisper_bench_encoder(const whisper_params & params) {
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
{
fprintf(stderr, "\n");
@ -92,22 +92,3 @@ int whisper_bench_encoder(const whisper_params & params) {
return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
int ret = -1;
switch (params.what) {
case 0: ret = whisper_bench_encoder(params); break;
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
}
return ret;
}

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -9,19 +9,7 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
@ -29,8 +17,9 @@ Initial tests show that this approach might be extremely efficient in terms of p
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Building

View File

@ -11,7 +11,6 @@
#include <SDL.h>
#include <SDL_audio.h>
#include <sstream>
#include <cassert>
#include <cstdio>
#include <fstream>
@ -26,7 +25,7 @@
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 8000;
int32_t command_ms = 4000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
@ -44,7 +43,6 @@ struct whisper_params {
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string commands;
std::string prompt;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -73,7 +71,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -106,7 +103,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, "\n");
}
@ -514,433 +510,6 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
return allowed_commands;
}
std::vector<std::string> get_words(const std::string &txt) {
std::vector<std::string> words;
std::istringstream iss(txt);
std::string word;
while (iss >> word) {
words.push_back(word);
}
return words;
}
// returns true if no exit event was received
bool process_sdl_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}
// command-list mode
// guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
int max_len = 0;
std::vector<std::vector<whisper_token>> allowed_tokens;
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < (int) cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
std::vector<whisper_token> k_tokens;
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
bool is_running = true;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
// compute probs from logits via softmax
{
float max = -1e9;
for (int i = 0; i < (int) probs.size(); ++i) {
max = std::max(max, logits[i]);
}
float sum = 0.0f;
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] = expf(logits[i] - max);
sum += probs[i];
}
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
}
}
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
}
fprintf(stdout, "\n");
}
}
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
const float prob = probs_id[0].first;
const int index = probs_id[0].second;
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
}
audio.clear();
}
}
return 0;
}
// always-prompt mode
// transcribe the voice into text after valid prompt
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool ask_prompt = true;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
const std::string k_prompt = params.prompt;
const int k_prompt_length = get_words(k_prompt).size();
fprintf(stderr, "\n");
fprintf(stderr, "%s: always-prompt mode\n", __func__);
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0;
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto words = get_words(txt);
std::string prompt;
std::string command;
for (int i = 0; i < (int) words.size(); ++i) {
if (i < k_prompt_length) {
prompt += words[i] + " ";
} else {
command += words[i] + " ";
}
}
const float sim = similarity(prompt, k_prompt);
//debug
//fprintf(stdout, "command size: %i\n", command_length);
if ((sim > 0.7f) && (command.size() > 0)) {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
}
fprintf(stdout, "\n");
audio.clear();
}
}
}
return 0;
}
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands.";
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
// main loop
while (is_running) {
// handle Ctrl + C
is_running = process_sdl_events();
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0;
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
}
return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
@ -956,7 +525,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
// print some info about the processing
{
@ -992,14 +561,300 @@ int main(int argc, char ** argv) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
int ret_val = 0;
int max_len = 0;
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
std::vector<std::string> allowed_commands;
std::vector<std::vector<whisper_token>> allowed_tokens;
std::string k_prompt;
std::vector<whisper_token> k_tokens;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2;
}
for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024];
allowed_tokens.emplace_back();
for (int l = 0; l < (int) cmd.size(); ++l) {
// NOTE: very important to add the whitespace !
// the reason is that the first decoded token starts with a whitespace too!
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
return 3;
}
if (n == 1) {
allowed_tokens.back().push_back(tokens[0]);
}
}
max_len = std::max(max_len, (int) cmd.size());
}
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
fprintf(stderr, "\n");
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
for (const auto & token : allowed_tokens[i]) {
fprintf(stderr, " %5d", token);
}
fprintf(stderr, " ]\n");
}
k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
}
k_prompt += allowed_commands[i];
}
k_prompt += ". selected word: ";
// tokenize prompt
{
k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
return 4;
}
k_tokens.resize(n);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
fprintf(stderr, "%s: tokens: [", __func__);
for (const auto & token : k_tokens) {
fprintf(stderr, " %d", token);
}
fprintf(stderr, " ]\n");
fprintf(stderr, "\n");
fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n");
} else {
ret_val = process_general_transcription(ctx, audio, params);
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
k_prompt = "Ok Whisper, start listening for commands.";
}
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (allowed_commands.empty()) {
// general-purpose mode
// freely transcribe the voice into text
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
int64_t t_ms = 0;
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
} else {
// command-list mode
// guide the transcription to match the most likely command from a provided list
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now();
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
const auto * probs = whisper_get_probs(ctx);
std::vector<std::pair<float, int>> probs_id;
double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
}
fprintf(stdout, "\n");
}
}
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
audio.clear();
}
}
}
audio.pause();
@ -1007,5 +862,5 @@ int main(int argc, char ** argv) {
whisper_print_timings(ctx);
whisper_free(ctx);
return ret_val;
return 0;
}

View File

@ -1,6 +1,3 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -59,12 +59,8 @@ struct whisper_params {
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
float word_thold = 0.01f;
bool speed_up = false;
bool translate = false;
@ -73,7 +69,6 @@ struct whisper_params {
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
@ -84,7 +79,6 @@ struct whisper_params {
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {};
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -109,11 +103,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -121,8 +111,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_outp.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -146,36 +134,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}
@ -191,81 +173,90 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker;
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
@ -334,31 +325,6 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ') {
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
}
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
}
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
@ -492,7 +458,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
@ -517,7 +483,6 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@ -563,7 +528,7 @@ int main(int argc, char ** argv) {
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 8;
}
@ -635,8 +600,6 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
@ -650,17 +613,12 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s };
@ -695,34 +653,27 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_outp + ".txt";
const auto fname_txt = fname_inp + ".txt";
output_txt(ctx, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_outp + ".vtt";
const auto fname_vtt = fname_inp + ".vtt";
output_vtt(ctx, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_outp + ".srt";
const auto fname_srt = fname_inp + ".srt";
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_outp + ".wts";
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_outp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
}
}

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)

View File

@ -49,9 +49,6 @@ void stream_main(size_t index) {
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads);
@ -132,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -8,7 +8,6 @@
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert>
#include <cstdio>
#include <string>
@ -145,8 +144,8 @@ private:
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
@ -156,8 +155,6 @@ private:
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
@ -423,21 +420,20 @@ int main(int argc, char ** argv) {
return 1;
}
params.keep_ms = std::min(params.keep_ms, params.step_ms);
params.length_ms = std::max(params.length_ms, params.step_ms);
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
params.no_timestamps = !use_vad;
params.no_context |= use_vad;
params.max_tokens = 0;
params.no_timestamps = !use_vad;
params.no_context = use_vad;
params.max_tokens = 0;
// init audio
@ -457,10 +453,10 @@ int main(int argc, char ** argv) {
exit(0);
}
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
struct whisper_context * ctx = whisper_init(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens;
@ -487,7 +483,7 @@ int main(int argc, char ** argv) {
params.no_timestamps ? 0 : 1);
if (!use_vad) {
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
} else {
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
}
@ -616,9 +612,6 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

View File

@ -9,8 +9,6 @@ add_executable(${TARGET}
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
@ -33,8 +31,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s INITIAL_MEMORY=1600MB \
-s TOTAL_MEMORY=1600MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.8GB of RAM
- The web-page uses about 1.6GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.

View File

@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -8,9 +8,6 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -498,7 +498,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
struct whisper_context * ctx_wsp = whisper_init(params.model_wsp.c_str());
// gpt init

View File

@ -64,21 +64,16 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
private suspend fun copyAssets() = withContext(Dispatchers.IO) {
modelsPath.mkdirs()
samplesPath.mkdirs()
//application.copyData("models", modelsPath, ::printMessage)
application.copyData("models", modelsPath, ::printMessage)
application.copyData("samples", samplesPath, ::printMessage)
printMessage("All data copied to working directory.\n")
}
private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
printMessage("Loading model...\n")
val models = application.assets.list("models/")
if (models != null) {
whisperContext = WhisperContext.createContextFromAsset(application.assets, "models/" + models[0])
printMessage("Loaded model ${models[0]}.\n")
}
//val firstModel = modelsPath.listFiles()!!.first()
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
val firstModel = modelsPath.listFiles()!!.first()
whisperContext = WhisperContext.createContext(firstModel.absolutePath)
printMessage("Loaded model ${firstModel.name}.\n")
}
fun transcribeSample() = viewModelScope.launch {

View File

@ -1,11 +1,9 @@
package com.whispercppdemo.whisper
import android.content.res.AssetManager
import android.os.Build
import android.util.Log
import kotlinx.coroutines.*
import java.io.File
import java.io.InputStream
import java.util.concurrent.Executors
private const val LOG_TAG = "LibWhisper"
@ -41,31 +39,13 @@ class WhisperContext private constructor(private var ptr: Long) {
}
companion object {
fun createContextFromFile(filePath: String): WhisperContext {
fun createContext(filePath: String): WhisperContext {
val ptr = WhisperLib.initContext(filePath)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context with path $filePath")
}
return WhisperContext(ptr)
}
fun createContextFromInputStream(stream: InputStream): WhisperContext {
val ptr = WhisperLib.initContextFromInputStream(stream)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context from input stream")
}
return WhisperContext(ptr)
}
fun createContextFromAsset(assetManager: AssetManager, assetPath: String): WhisperContext {
val ptr = WhisperLib.initContextFromAsset(assetManager, assetPath)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context from asset $assetPath")
}
return WhisperContext(ptr)
}
}
}
@ -96,8 +76,6 @@ private class WhisperLib {
}
// JNI methods
external fun initContextFromInputStream(inputStream: InputStream): Long
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)

View File

@ -1,5 +1,5 @@
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
LOCAL_LDLIBS := -landroid -llog
LOCAL_LDLIBS := -llog
# Make the final output library smaller by only keeping the symbols referenced from the app.
ifneq ($(APP_OPTIM),debug)

View File

@ -1,17 +1,13 @@
#include <jni.h>
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#include <stdlib.h>
#include <sys/sysinfo.h>
#include <string.h>
#include "whisper.h"
#define UNUSED(x) (void)(x)
#define TAG "JNI"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
static inline int min(int a, int b) {
return (a < b) ? a : b;
@ -21,132 +17,13 @@ static inline int max(int a, int b) {
return (a > b) ? a : b;
}
struct input_stream_context {
size_t offset;
JNIEnv * env;
jobject thiz;
jobject input_stream;
jmethodID mid_available;
jmethodID mid_read;
};
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
struct input_stream_context* is = (struct input_stream_context*)ctx;
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
if (size_to_copy != read_size || size_to_copy != n_read) {
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
}
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
memcpy(output, byte_array_elements, size_to_copy);
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
(*is->env)->DeleteLocalRef(is->env, byte_array);
is->offset += size_to_copy;
return size_to_copy;
}
bool inputStreamEof(void * ctx) {
struct input_stream_context* is = (struct input_stream_context*)ctx;
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
return result <= 0;
}
void inputStreamClose(void * ctx) {
}
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
JNIEnv *env, jobject thiz, jobject input_stream) {
UNUSED(thiz);
struct whisper_context *context = NULL;
struct whisper_model_loader loader = {};
struct input_stream_context inp_ctx = {};
inp_ctx.offset = 0;
inp_ctx.env = env;
inp_ctx.thiz = thiz;
inp_ctx.input_stream = input_stream;
jclass cls = (*env)->GetObjectClass(env, input_stream);
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
loader.context = &inp_ctx;
loader.read = inputStreamRead;
loader.eof = inputStreamEof;
loader.close = inputStreamClose;
loader.eof(loader.context);
context = whisper_init(&loader);
return (jlong) context;
}
static size_t asset_read(void *ctx, void *output, size_t read_size) {
return AAsset_read((AAsset *) ctx, output, read_size);
}
static bool asset_is_eof(void *ctx) {
return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
}
static void asset_close(void *ctx) {
AAsset_close((AAsset *) ctx);
}
static struct whisper_context *whisper_init_from_asset(
JNIEnv *env,
jobject assetManager,
const char *asset_path
) {
LOGI("Loading model from asset '%s'\n", asset_path);
AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING);
if (!asset) {
LOGW("Failed to open '%s'\n", asset_path);
return NULL;
}
whisper_model_loader loader = {
.context = asset,
.read = &asset_read,
.eof = &asset_is_eof,
.close = &asset_close
};
return whisper_init(&loader);
}
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromAsset(
JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) {
UNUSED(thiz);
struct whisper_context *context = NULL;
const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL);
context = whisper_init_from_asset(env, assetManager, asset_path_chars);
(*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars);
return (jlong) context;
}
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
JNIEnv *env, jobject thiz, jstring model_path_str) {
UNUSED(thiz);
struct whisper_context *context = NULL;
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
context = whisper_init_from_file(model_path_chars);
context = whisper_init(model_path_chars);
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
return (jlong) context;
}

View File

@ -0,0 +1,10 @@
## This file is automatically generated by Android Studio.
# Do not modify this file -- YOUR CHANGES WILL BE ERASED!
#
# This file should *NOT* be checked into Version Control Systems,
# as it contains information specific to your local configuration.
#
# Location of the SDK. This is only used by Gradle.
# For customization when using a Version Control System, please read the
# header note.
sdk.dir=/Users/kevin/Library/Android/sdk

View File

@ -61,7 +61,7 @@ void AudioInputCallback(void * inUserData,
NSLog(@"Loading model from %@", modelPath);
// create ggml context
stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);
stateInp.ctx = whisper_init([modelPath UTF8String]);
// check if the model was loaded successfully
if (stateInp.ctx == NULL) {

View File

@ -10,5 +10,3 @@ To use:
5. Select the "release" build configuration under "Run", then deploy and run to your device.
[^1]: I recommend the tiny, base or small models for running on an iOS device.
![image](https://user-images.githubusercontent.com/1991296/212539216-0aef65e4-f882-480a-8358-0f816838fd52.png)

View File

@ -55,7 +55,7 @@ actor WhisperContext {
}
static func createContext(path: String) throws -> WhisperContext {
let context = whisper_init_from_file(path)
let context = whisper_init(path)
if let context {
return WhisperContext(context: context)
} else {

View File

@ -35,10 +35,10 @@
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = "<group>"; };
0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = "<group>"; };
0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
@ -129,8 +129,7 @@
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
0AAC5DC829539EB0003032C3 /* whisper.h */,
);
name = whisper.cpp;
path = ../..;
path = whisper.cpp;
sourceTree = "<group>";
};
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
)
@ -32,8 +30,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1500MB \
-s TOTAL_MEMORY=1500MB \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str());
g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) {
return i + 1;
} else {

View File

@ -46,12 +46,10 @@
<div id="model">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
<button id="fetch-whisper-small" onclick="loadWhisper('small')">small (466 MB)</button>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
<span id="fetch-whisper-progress"></span>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
@ -286,33 +284,27 @@
}
reader.readAsArrayBuffer(file);
document.getElementById('fetch-whisper-tiny-en' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'tiny': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'small': 'https://whisper.ggerganov.com/ggml-model-whisper-small.bin',
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'tiny': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
};
let sizes = {
'tiny.en': 75,
'tiny': 75,
'base.en': 142,
'base': 142,
'small.en': 466,
'small': 466,
'tiny.en': 75,
'tiny': 75,
'base.en': 142,
'base': 142,
};
let url = urls[model];
@ -321,14 +313,12 @@
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model;
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
@ -337,14 +327,12 @@
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = '';
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);

View File

@ -12,18 +12,6 @@ fi
models=( "tiny" "base" "small" "medium" "large" )
printf "\n"
printf "Running memcpy benchmark with 1 thread\n"
printf "\n"
./bench -w 1 -t 1 2>&1
printf "\n"
printf "Running ggml_mul_mat benchmark with $n_threads threads\n"
printf "\n"
./bench -w 2 -t $n_threads 2>&1
printf "\n"
printf "Running benchmark for all models\n"
printf "This can take a while!\n"
@ -68,3 +56,4 @@ for model in "${models[@]}"; do
printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"
done

1274
ggml.c

File diff suppressed because it is too large Load Diff

14
ggml.h
View File

@ -177,11 +177,13 @@ extern "C" {
#include <stddef.h>
#include <stdbool.h>
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
#define GGML_MAX_THREADS 64
#define GGML_MAX_THREAD_POOLS 16
#ifdef __ARM_NEON
// we use the built-in 16-bit float type
@ -731,8 +733,6 @@ int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
#ifdef __cplusplus
}

View File

@ -56,7 +56,7 @@ def bytes_to_unicode():
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""

View File

@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,6 @@
#ifndef WHISPER_H
#define WHISPER_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
@ -41,7 +40,7 @@ extern "C" {
//
// ...
//
// struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
// struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
//
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
// fprintf(stderr, "failed to process audio\n");
@ -74,7 +73,6 @@ extern "C" {
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
@ -86,20 +84,9 @@ extern "C" {
float vlen; // voice length of the token
} whisper_token_data;
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
WHISPER_API struct whisper_context * whisper_init(const char * path_model);
// Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx);
@ -137,7 +124,6 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
@ -145,6 +131,14 @@ extern "C" {
int n_past,
int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
@ -154,7 +148,7 @@ extern "C" {
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
int n_max_tokens);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
@ -183,14 +177,10 @@ extern "C" {
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// The probabilities for the next token
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -219,8 +209,8 @@ extern "C" {
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
};
// Text segment callback
@ -240,17 +230,17 @@ extern "C" {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int n_max_text_ctx;
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
bool no_context;
bool single_segment; // force single segment output (useful for streaming)
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
@ -260,11 +250,10 @@ extern "C" {
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide to the whisper decoder as initial prompt
// tokens to provide the whisper model as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
@ -272,35 +261,19 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
int n_past;
} greedy;
struct {
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
int n_past;
int beam_width;
int n_best;
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};
@ -350,13 +323,6 @@ extern "C" {
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
////////////////////////////////////////////////////////////////////////////
// Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
#ifdef __cplusplus
}
#endif