mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
bindings : initial import of golang bindings (#287)
* Initial import of golang bindings * Updated makefile rules * Updated bindings * Makefile update to add in more tests
This commit is contained in:
parent
90564f85f9
commit
231bebca7d
3
bindings/go/.gitignore
vendored
Executable file
3
bindings/go/.gitignore
vendored
Executable file
@ -0,0 +1,3 @@
|
||||
build
|
||||
models
|
||||
go.sum
|
21
bindings/go/LICENSE
Executable file
21
bindings/go/LICENSE
Executable file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 David Thorpe
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
38
bindings/go/Makefile
Executable file
38
bindings/go/Makefile
Executable file
@ -0,0 +1,38 @@
|
||||
CMAKE := $(shell which cmake)
|
||||
BUILD_DIR := "build"
|
||||
MODELS_DIR := "models"
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
C_INCLUDE_PATH := "../.."
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
||||
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
||||
|
||||
test: model-small whisper
|
||||
@go mod tidy
|
||||
@go test -v .
|
||||
@go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models small.en
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper
|
||||
@echo Build example $(notdir $@)
|
||||
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
@install -d ${BUILD_DIR}
|
||||
@echo Mkdir ${MODELS_DIR}
|
||||
@install -d ${MODELS_DIR}
|
||||
|
||||
clean:
|
||||
@echo Clean
|
||||
@rm -fr $(BUILD_DIR)
|
||||
@go mod tidy
|
||||
@go clean
|
77
bindings/go/README.md
Executable file
77
bindings/go/README.md
Executable file
@ -0,0 +1,77 @@
|
||||
# Go bindings for Whisper
|
||||
|
||||
This package provides Go bindings for whisper.cpp. They have been tested on:
|
||||
|
||||
* Darwin (OS X) 12.6 on x64_64
|
||||
* Debian Linux on arm64
|
||||
* Fedora Linux on x86_64
|
||||
|
||||
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||
is as follows:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var modelpath string // Path to the model
|
||||
var samples []float32 // Samples to process
|
||||
|
||||
// Load the model
|
||||
model, err := whisper.New(modelpath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := context.Process(samples, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Building & Testing
|
||||
|
||||
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp/bindings/go
|
||||
make test
|
||||
```
|
||||
|
||||
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
||||
|
||||
```bash
|
||||
make examples
|
||||
```
|
||||
|
||||
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-model-download -out models
|
||||
```
|
||||
|
||||
And you can then test a model against samples with the following command:
|
||||
|
||||
```bash
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
|
5
bindings/go/doc.go
Normal file
5
bindings/go/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
/*
|
||||
github.com/ggerganov/whisper.cpp/bindings/go
|
||||
provides a speech-to-text service bindings for the Go programming language.
|
||||
*/
|
||||
package whisper
|
30
bindings/go/examples/go-model-download/context.go
Executable file
30
bindings/go/examples/go-model-download/context.go
Executable file
@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
// ContextForSignal returns a context object which is cancelled when a signal
|
||||
// is received. It returns nil if no signal parameter is provided
|
||||
func ContextForSignal(signals ...os.Signal) context.Context {
|
||||
if len(signals) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := make(chan os.Signal)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Send message on channel when signal received
|
||||
signal.Notify(ch, signals...)
|
||||
|
||||
// When any signal received, call cancel
|
||||
go func() {
|
||||
<-ch
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Return success
|
||||
return ctx
|
||||
}
|
206
bindings/go/examples/go-model-download/main.go
Executable file
206
bindings/go/examples/go-model-download/main.go
Executable file
@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/" // The location of the models
|
||||
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
||||
)
|
||||
|
||||
var (
|
||||
// The output folder. When not set, use current working directory.
|
||||
flagOut = flag.String("out", "", "Output folder")
|
||||
|
||||
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
||||
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
||||
|
||||
// Quiet parameter - will not print progress if set
|
||||
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MAIN
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
name := filepath.Base(flag.CommandLine.Name())
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Get output path
|
||||
out, err := GetOut()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Create context which quits on SIGINT or SIGQUIT
|
||||
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
||||
|
||||
// Progress filehandle
|
||||
progress := os.Stdout
|
||||
if *flagQuiet {
|
||||
progress, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
defer progress.Close()
|
||||
}
|
||||
|
||||
// Download models - exit on error or interrupt
|
||||
for _, model := range GetModels() {
|
||||
url, err := URLForModel(model)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
continue
|
||||
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
||||
continue
|
||||
} else if err == context.Canceled {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "\nInterrupted")
|
||||
break
|
||||
} else if err == context.DeadlineExceeded {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(progress, "Timeout downloading model")
|
||||
continue
|
||||
} else {
|
||||
os.Remove(path)
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// GetOut returns the path to the output directory
|
||||
func GetOut() (string, error) {
|
||||
if *flagOut == "" {
|
||||
return os.Getwd()
|
||||
}
|
||||
if info, err := os.Stat(*flagOut); err != nil {
|
||||
return "", err
|
||||
} else if !info.IsDir() {
|
||||
return "", fmt.Errorf("not a directory: %s", info.Name())
|
||||
} else {
|
||||
return *flagOut, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns the list of models to download
|
||||
func GetModels() []string {
|
||||
if flag.NArg() == 0 {
|
||||
return modelNames
|
||||
} else {
|
||||
return flag.Args()
|
||||
}
|
||||
}
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = srcPathPrefix + "-" + model + srcExt
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Download downloads the model from the given URL to the given output directory
|
||||
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
||||
// Create HTTP client
|
||||
client := http.Client{
|
||||
Timeout: *flagTimeout,
|
||||
}
|
||||
|
||||
// Initiate the download
|
||||
req, err := http.NewRequest("GET", model, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
||||
}
|
||||
|
||||
// If output file exists and is the same size as the model, skip
|
||||
path := filepath.Join(out, filepath.Base(model))
|
||||
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
||||
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Create file
|
||||
w, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Report
|
||||
fmt.Fprintln(p, "Downloading", model, "to", out)
|
||||
|
||||
// Progressively download the model
|
||||
data := make([]byte, bufSize)
|
||||
count, pct := int64(0), int64(0)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Cancelled, return error
|
||||
return path, ctx.Err()
|
||||
case <-ticker.C:
|
||||
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
||||
default:
|
||||
// Read body
|
||||
n, err := resp.Body.Read(data)
|
||||
if err != nil {
|
||||
DownloadReport(p, pct, count, resp.ContentLength)
|
||||
return path, err
|
||||
} else if m, err := w.Write(data[:n]); err != nil {
|
||||
return path, err
|
||||
} else {
|
||||
count += int64(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Report periodically reports the download progress when percentage changes
|
||||
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
|
||||
pct_ := count * 100 / total
|
||||
if pct_ > pct {
|
||||
fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
|
||||
}
|
||||
return pct_
|
||||
}
|
61
bindings/go/examples/go-whisper/flags.go
Executable file
61
bindings/go/examples/go-whisper/flags.go
Executable file
@ -0,0 +1,61 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type Flags struct {
|
||||
*flag.FlagSet
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewFlags(name string, args []string) (*Flags, error) {
|
||||
flags := &Flags{
|
||||
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||
}
|
||||
|
||||
// Register the command line arguments
|
||||
registerFlags(flags)
|
||||
|
||||
// Parse command line
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (flags *Flags) GetModel() string {
|
||||
return flags.Lookup("model").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) GetLanguage() string {
|
||||
return flags.Lookup("language").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func registerFlags(flag *Flags) {
|
||||
flag.String("model", "", "Path to the model file")
|
||||
flag.String("language", "", "Language")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Bool("tokens", false, "Display tokens")
|
||||
}
|
44
bindings/go/examples/go-whisper/main.go
Executable file
44
bindings/go/examples/go-whisper/main.go
Executable file
@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
||||
if err == flag.ErrHelp {
|
||||
os.Exit(0)
|
||||
} else if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
} else if flags.GetModel() == "" {
|
||||
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
||||
os.Exit(1)
|
||||
} else if flags.NArg() == 0 {
|
||||
fmt.Fprintln(os.Stderr, "No input files specified")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(flags.GetModel())
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process files
|
||||
for _, filename := range flags.Args() {
|
||||
fmt.Println("Processing", filename)
|
||||
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
80
bindings/go/examples/go-whisper/process.go
Executable file
80
bindings/go/examples/go-whisper/process.go
Executable file
@ -0,0 +1,80 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// Package imports
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
|
||||
var data []float32
|
||||
|
||||
// Create processing context
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open the file
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file
|
||||
dec := wav.NewDecoder(fh)
|
||||
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||
return err
|
||||
} else if dec.SampleRate != whisper.SampleRate {
|
||||
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||
} else if dec.NumChans != 1 {
|
||||
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||
} else {
|
||||
data = buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
var cb whisper.SegmentCallback
|
||||
if lang != "" {
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if speedup {
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if tokens {
|
||||
cb = func(segment whisper.Segment) {
|
||||
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
for _, token := range segment.Tokens {
|
||||
fmt.Printf("%q ", token.Text)
|
||||
}
|
||||
fmt.Println("")
|
||||
}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
16
bindings/go/go.mod
Executable file
16
bindings/go/go.mod
Executable file
@ -0,0 +1,16 @@
|
||||
module github.com/ggerganov/whisper.cpp/bindings/go
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
134
bindings/go/params.go
Normal file
134
bindings/go/params.go
Normal file
@ -0,0 +1,134 @@
|
||||
package whisper
|
||||
|
||||
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
||||
// structures, which are used by the whisper_full() function.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#include <whisper.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
func (p *Params) SetTranslate(v bool) {
|
||||
p.translate = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetNoContext(v bool) {
|
||||
p.no_context = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSingleSegment(v bool) {
|
||||
p.single_segment = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintSpecial(v bool) {
|
||||
p.print_special = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintProgress(v bool) {
|
||||
p.print_progress = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintRealtime(v bool) {
|
||||
p.print_realtime = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetPrintTimestamps(v bool) {
|
||||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSpeedup(v bool) {
|
||||
p.speed_up = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
str := C.whisper_lang_str(C.int(lang))
|
||||
if str == nil {
|
||||
return ErrInvalidLanguage
|
||||
} else {
|
||||
p.language = str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Params) Language() int {
|
||||
if p.language == nil {
|
||||
return -1
|
||||
}
|
||||
return int(C.whisper_lang_id(p.language))
|
||||
}
|
||||
|
||||
func (p *Params) SetThreads(threads int) {
|
||||
p.n_threads = C.int(threads)
|
||||
}
|
||||
|
||||
func (p *Params) SetOffset(offset_ms int) {
|
||||
p.offset_ms = C.int(offset_ms)
|
||||
}
|
||||
|
||||
func (p *Params) SetDuration(duration_ms int) {
|
||||
p.duration_ms = C.int(duration_ms)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toBool(v bool) C.bool {
|
||||
if v {
|
||||
return C.bool(true)
|
||||
}
|
||||
return C.bool(false)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (p *Params) String() string {
|
||||
str := "<whisper.params"
|
||||
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
||||
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
||||
if p.language != nil {
|
||||
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
||||
}
|
||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
if p.no_context {
|
||||
str += " no_context"
|
||||
}
|
||||
if p.single_segment {
|
||||
str += " single_segment"
|
||||
}
|
||||
if p.print_special {
|
||||
str += " print_special"
|
||||
}
|
||||
if p.print_progress {
|
||||
str += " print_progress"
|
||||
}
|
||||
if p.print_realtime {
|
||||
str += " print_realtime"
|
||||
}
|
||||
if p.print_timestamps {
|
||||
str += " print_timestamps"
|
||||
}
|
||||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.speed_up {
|
||||
str += " speed_up"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
27
bindings/go/pkg/whisper/consts.go
Executable file
27
bindings/go/pkg/whisper/consts.go
Executable file
@ -0,0 +1,27 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
// SampleRate is the sample rate of the audio data.
|
||||
const SampleRate = whisper.SampleRate
|
||||
|
||||
// SampleBits is the number of bytes per sample.
|
||||
const SampleBits = whisper.SampleBits
|
145
bindings/go/pkg/whisper/context.go
Executable file
145
bindings/go/pkg/whisper/context.go
Executable file
@ -0,0 +1,145 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type context struct {
|
||||
n int
|
||||
model *model
|
||||
params whisper.Params
|
||||
}
|
||||
|
||||
// Make sure context adheres to the interface
|
||||
var _ Context = (*context)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
|
||||
// Return success
|
||||
return context, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Set the language to use for speech recognition.
|
||||
func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if cb != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
// We don't do parallel processing at the moment
|
||||
processors := 0
|
||||
if processors > 1 {
|
||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||
if cb != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
cb(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the next segment of tokens
|
||||
func (context *context) NextSegment() (Segment, error) {
|
||||
if context.model.ctx == nil {
|
||||
return Segment{}, ErrInternalAppError
|
||||
}
|
||||
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
// Populate result
|
||||
result := toSegment(context.model.ctx, context.n)
|
||||
|
||||
// Increment the cursor
|
||||
context.n++
|
||||
|
||||
// Return success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokens(ctx, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
55
bindings/go/pkg/whisper/context_test.go
Executable file
55
bindings/go/pkg/whisper/context_test.go
Executable file
@ -0,0 +1,55 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "../../models/ggml-tiny.bin"
|
||||
SamplePath = "../../samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
assert.NoError(model.Close())
|
||||
|
||||
t.Log("languages=", model.Languages())
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
// Get context for decoding
|
||||
ctx, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
assert.NotNil(ctx)
|
||||
|
||||
}
|
4
bindings/go/pkg/whisper/doc.go
Executable file
4
bindings/go/pkg/whisper/doc.go
Executable file
@ -0,0 +1,4 @@
|
||||
/*
|
||||
This is the higher-level speech-to-text whisper.cpp API for go
|
||||
*/
|
||||
package whisper
|
63
bindings/go/pkg/whisper/interface.go
Executable file
63
bindings/go/pkg/whisper/interface.go
Executable file
@ -0,0 +1,63 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
// SegmentCallback is the callback function for processing segments in real
|
||||
// time. It is called during the Process function
|
||||
type SegmentCallback func(Segment)
|
||||
|
||||
// Model is the interface to a whisper model. Create a new model with the
|
||||
// function whisper.New(string)
|
||||
type Model interface {
|
||||
io.Closer
|
||||
|
||||
// Return a new speech-to-text context.
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return all languages supported.
|
||||
Languages() []string
|
||||
}
|
||||
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||
Language() string // Get language
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
// callback function during processing.
|
||||
Process([]float32, SegmentCallback) error
|
||||
|
||||
// After process is called, return segments until the end of the stream
|
||||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
type Segment struct {
|
||||
// Segment Number
|
||||
Num int
|
||||
|
||||
// Time beginning and end timestamps for the segment.
|
||||
Start, End time.Duration
|
||||
|
||||
// The text of the segment.
|
||||
Text string
|
||||
|
||||
// The tokens of the segment.
|
||||
Tokens []Token
|
||||
}
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
}
|
95
bindings/go/pkg/whisper/model.go
Executable file
95
bindings/go/pkg/whisper/model.go
Executable file
@ -0,0 +1,95 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type model struct {
|
||||
path string
|
||||
ctx *whisper.Context
|
||||
}
|
||||
|
||||
// Make sure model adheres to the interface
|
||||
var _ Model = (*model)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (*model, error) {
|
||||
model := new(model)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
if model.ctx != nil {
|
||||
model.ctx.Whisper_free()
|
||||
}
|
||||
|
||||
// Release resources
|
||||
model.ctx = nil
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
|
||||
func (model *model) String() string {
|
||||
str := "<whisper.model"
|
||||
if model.ctx != nil {
|
||||
str += fmt.Sprintf(" model=%q", model.path)
|
||||
}
|
||||
return str + ">"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (model *model) NewContext() (Context, error) {
|
||||
if model.ctx == nil {
|
||||
return nil, ErrInternalAppError
|
||||
}
|
||||
|
||||
// Create new context
|
||||
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
|
||||
// Return new context
|
||||
return NewContext(model, params)
|
||||
}
|
BIN
bindings/go/samples/jfk.wav
Executable file
BIN
bindings/go/samples/jfk.wav
Executable file
Binary file not shown.
412
bindings/go/whisper.go
Normal file
412
bindings/go/whisper.go
Normal file
@ -0,0 +1,412 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/../..
|
||||
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
extern bool callEncoderBegin(void* user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callNewSegment(user_data, n_new);
|
||||
}
|
||||
}
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
return callEncoderBegin(user_data);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get default parameters and set callbacks
|
||||
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||
params.new_segment_callback = whisper_new_segment_cb;
|
||||
params.new_segment_callback_user_data = (void*)(ctx);
|
||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type (
|
||||
Context C.struct_whisper_context
|
||||
Token C.whisper_token
|
||||
TokenData C.struct_whisper_token_data
|
||||
SamplingStrategy C.enum_whisper_sampling_strategy
|
||||
Params C.struct_whisper_full_params
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GLOBALS
|
||||
|
||||
const (
|
||||
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
||||
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
||||
)
|
||||
|
||||
const (
|
||||
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
||||
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
||||
NumFFT = C.WHISPER_N_FFT
|
||||
NumMEL = C.WHISPER_N_MEL
|
||||
HopLength = C.WHISPER_HOP_LENGTH
|
||||
ChunkSize = C.WHISPER_CHUNK_SIZE
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
||||
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
||||
ErrConversionFailed = errors.New("whisper_convert failed")
|
||||
ErrInvalidLanguage = errors.New("invalid language")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Allocates all memory needed for the model and loads the model from the given file.
|
||||
// Returns NULL on failure.
|
||||
func Whisper_init(path string) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if ctx := C.whisper_init(cPath); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the provided whisper context.
|
||||
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
||||
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
||||
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
// Make sure to call whisper_encode() first.
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
// n_past is the number of tokens to use from previous decoder calls.
|
||||
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
||||
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// whisper_sample_best() returns the token with the highest probability
|
||||
func (ctx *Context) Whisper_sample_best() TokenData {
|
||||
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// whisper_sample_timestamp() returns the most probable timestamp token
|
||||
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
|
||||
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
|
||||
}
|
||||
|
||||
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success
|
||||
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
cText := C.CString(text)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
||||
return int(n), nil
|
||||
} else {
|
||||
return 0, ErrTokenizerFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
func Whisper_lang_max_id() int {
|
||||
return int(C.whisper_lang_max_id())
|
||||
}
|
||||
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
||||
// returns empty string if not found
|
||||
func Whisper_lang_str(id int) string {
|
||||
return C.GoString(C.whisper_lang_str(C.int(id)))
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
||||
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||
return nil, ErrAutoDetectFailed
|
||||
} else {
|
||||
return probs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len() int {
|
||||
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_vocab() int {
|
||||
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// The probabilities for the next token
|
||||
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
||||
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
||||
//}
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
||||
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_eot() Token {
|
||||
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_sot() Token {
|
||||
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_prev() Token {
|
||||
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_solm() Token {
|
||||
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_not() Token {
|
||||
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_beg() Token {
|
||||
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
||||
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_translate() Token {
|
||||
return Token(C.whisper_token_translate())
|
||||
}
|
||||
|
||||
// Task tokens
|
||||
func Whisper_token_transcribe() Token {
|
||||
return Token(C.whisper_token_transcribe())
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_print_timings() {
|
||||
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Performance information
|
||||
func (ctx *Context) Whisper_reset_timings() {
|
||||
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Print system information
|
||||
func Whisper_print_system_info() string {
|
||||
return C.GoString(C.whisper_print_system_info())
|
||||
}
|
||||
|
||||
// Return default parameters for a strategy
|
||||
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
||||
// Get default parameters
|
||||
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
||||
}
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
|
||||
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Number of generated text segments.
|
||||
// A segment can be a few words, a sentence, or even a paragraph.
|
||||
func (ctx *Context) Whisper_full_n_segments() int {
|
||||
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the text of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get number of tokens in the specified segment.
|
||||
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the token text of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the token of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CALLBACKS
|
||||
|
||||
var (
|
||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||
)
|
||||
|
||||
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
if fn == nil {
|
||||
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
//export callNewSegment
|
||||
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
if fn, ok := cbNewSegment[user_data]; ok {
|
||||
fn(int(new))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
if fn() {
|
||||
return C.bool(true)
|
||||
} else {
|
||||
return C.bool(false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
110
bindings/go/whisper_test.go
Normal file
110
bindings/go/whisper_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
wav "github.com/go-audio/wav"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "models/ggml-small.en.bin"
|
||||
SamplePath = "samples/jfk.wav"
|
||||
)
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
ctx.Whisper_free()
|
||||
}
|
||||
|
||||
func Test_Whisper_001(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Run whisper
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
assert.GreaterOrEqual(num_segments, 1)
|
||||
for i := 0; i < num_segments; i++ {
|
||||
str := ctx.Whisper_full_get_segment_text(i)
|
||||
assert.NotEmpty(str)
|
||||
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
||||
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
||||
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_002(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
assert.NotEmpty(str)
|
||||
t.Log(str)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_003(t *testing.T) {
|
||||
threads := runtime.NumCPU()
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
|
||||
// Make the model
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
// Get MEL
|
||||
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
||||
|
||||
// Get Languages
|
||||
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
||||
assert.NoError(err)
|
||||
for i, p := range languages {
|
||||
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user