diff --git a/bindings/go/.gitignore b/bindings/go/.gitignore new file mode 100755 index 00000000..b4e10840 --- /dev/null +++ b/bindings/go/.gitignore @@ -0,0 +1,3 @@ +build +models +go.sum diff --git a/bindings/go/LICENSE b/bindings/go/LICENSE new file mode 100755 index 00000000..a8f0d7b9 --- /dev/null +++ b/bindings/go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 David Thorpe + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/bindings/go/Makefile b/bindings/go/Makefile new file mode 100755 index 00000000..33742125 --- /dev/null +++ b/bindings/go/Makefile @@ -0,0 +1,38 @@ +CMAKE := $(shell which cmake) +BUILD_DIR := "build" +MODELS_DIR := "models" +EXAMPLES_DIR := $(wildcard examples/*) +C_INCLUDE_PATH := "../.." + +all: clean whisper examples + +whisper: mkdir + @echo Build whisper + @${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on + @${CMAKE} --build ${BUILD_DIR} --target whisper + +test: model-small whisper + @go mod tidy + @go test -v . + @go test -v ./pkg/whisper/... + +examples: $(EXAMPLES_DIR) + +model-small: mkdir examples/go-model-download + @${BUILD_DIR}/go-model-download -out models small.en + +$(EXAMPLES_DIR): mkdir whisper + @echo Build example $(notdir $@) + @go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@ + +mkdir: + @echo Mkdir ${BUILD_DIR} + @install -d ${BUILD_DIR} + @echo Mkdir ${MODELS_DIR} + @install -d ${MODELS_DIR} + +clean: + @echo Clean + @rm -fr $(BUILD_DIR) + @go mod tidy + @go clean diff --git a/bindings/go/README.md b/bindings/go/README.md new file mode 100755 index 00000000..8ae89c77 --- /dev/null +++ b/bindings/go/README.md @@ -0,0 +1,77 @@ +# Go bindings for Whisper + +This package provides Go bindings for whisper.cpp. They have been tested on: + + * Darwin (OS X) 12.6 on x64_64 + * Debian Linux on arm64 + * Fedora Linux on x86_64 + +The "low level" bindings are in the `bindings/go` directory and there is a more +Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage +is as follows: + +```go +import ( + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func main() { + var modelpath string // Path to the model + var samples []float32 // Samples to process + + // Load the model + model, err := whisper.New(modelpath) + if err != nil { + panic(err) + } + defer model.Close() + + // Process samples + context, err := model.NewContext() + if err != nil { + panic(err) + } + if err := context.Process(samples, nil); err != nil { + return err + } + + // Print out the results + for { + segment, err := context.NextSegment() + if err != nil { + break + } + fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text) + } +} +``` + +## Building & Testing + +In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with: + +```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp/bindings/go +make test +``` + +This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples: + +```bash +make examples +``` + +The examples are placed in the `build` directory. Once built, you can download all the models with the following command: + +```bash +./build/go-model-download -out models +``` + +And you can then test a model against samples with the following command: + +```bash +./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav +``` + + diff --git a/bindings/go/doc.go b/bindings/go/doc.go new file mode 100644 index 00000000..dcc351f2 --- /dev/null +++ b/bindings/go/doc.go @@ -0,0 +1,5 @@ +/* +github.com/ggerganov/whisper.cpp/bindings/go +provides a speech-to-text service bindings for the Go programming language. +*/ +package whisper diff --git a/bindings/go/examples/go-model-download/context.go b/bindings/go/examples/go-model-download/context.go new file mode 100755 index 00000000..639d8f5b --- /dev/null +++ b/bindings/go/examples/go-model-download/context.go @@ -0,0 +1,30 @@ +package main + +import ( + "context" + "os" + "os/signal" +) + +// ContextForSignal returns a context object which is cancelled when a signal +// is received. It returns nil if no signal parameter is provided +func ContextForSignal(signals ...os.Signal) context.Context { + if len(signals) == 0 { + return nil + } + + ch := make(chan os.Signal) + ctx, cancel := context.WithCancel(context.Background()) + + // Send message on channel when signal received + signal.Notify(ch, signals...) + + // When any signal received, call cancel + go func() { + <-ch + cancel() + }() + + // Return success + return ctx +} diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go new file mode 100755 index 00000000..841a2c65 --- /dev/null +++ b/bindings/go/examples/go-model-download/main.go @@ -0,0 +1,206 @@ +package main + +import ( + "context" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "syscall" + "time" +) + +/////////////////////////////////////////////////////////////////////////////// +// CONSTANTS + +const ( + srcUrl = "https://huggingface.co/" // The location of the models + srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix + srcExt = ".bin" // Filename extension + bufSize = 1024 * 64 // Size of the buffer used for downloading the model +) + +var ( + // The models which will be downloaded, if no model is specified as an argument + modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"} +) + +var ( + // The output folder. When not set, use current working directory. + flagOut = flag.String("out", "", "Output folder") + + // HTTP timeout parameter - will timeout if takes longer than this to download a model + flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout") + + // Quiet parameter - will not print progress if set + flagQuiet = flag.Bool("quiet", false, "Quiet mode") +) + +/////////////////////////////////////////////////////////////////////////////// +// MAIN + +func main() { + flag.Usage = func() { + name := filepath.Base(flag.CommandLine.Name()) + fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] \n\n", name) + flag.PrintDefaults() + } + flag.Parse() + + // Get output path + out, err := GetOut() + if err != nil { + fmt.Fprintln(os.Stderr, "Error:", err) + os.Exit(-1) + } + + // Create context which quits on SIGINT or SIGQUIT + ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT) + + // Progress filehandle + progress := os.Stdout + if *flagQuiet { + progress, err = os.Open(os.DevNull) + if err != nil { + fmt.Fprintln(os.Stderr, "Error:", err) + os.Exit(-1) + } + defer progress.Close() + } + + // Download models - exit on error or interrupt + for _, model := range GetModels() { + url, err := URLForModel(model) + if err != nil { + fmt.Fprintln(os.Stderr, "Error:", err) + continue + } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF { + continue + } else if err == context.Canceled { + os.Remove(path) + fmt.Fprintln(progress, "\nInterrupted") + break + } else if err == context.DeadlineExceeded { + os.Remove(path) + fmt.Fprintln(progress, "Timeout downloading model") + continue + } else { + os.Remove(path) + fmt.Fprintln(os.Stderr, "Error:", err) + break + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// GetOut returns the path to the output directory +func GetOut() (string, error) { + if *flagOut == "" { + return os.Getwd() + } + if info, err := os.Stat(*flagOut); err != nil { + return "", err + } else if !info.IsDir() { + return "", fmt.Errorf("not a directory: %s", info.Name()) + } else { + return *flagOut, nil + } +} + +// GetModels returns the list of models to download +func GetModels() []string { + if flag.NArg() == 0 { + return modelNames + } else { + return flag.Args() + } +} + +// URLForModel returns the URL for the given model on huggingface.co +func URLForModel(model string) (string, error) { + url, err := url.Parse(srcUrl) + if err != nil { + return "", err + } else { + url.Path = srcPathPrefix + "-" + model + srcExt + } + return url.String(), nil +} + +// Download downloads the model from the given URL to the given output directory +func Download(ctx context.Context, p io.Writer, model, out string) (string, error) { + // Create HTTP client + client := http.Client{ + Timeout: *flagTimeout, + } + + // Initiate the download + req, err := http.NewRequest("GET", model, nil) + if err != nil { + return "", err + } + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%s: %s", model, resp.Status) + } + + // If output file exists and is the same size as the model, skip + path := filepath.Join(out, filepath.Base(model)) + if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength { + fmt.Fprintln(p, "Skipping", model, "as it already exists") + return "", nil + } + + // Create file + w, err := os.Create(path) + if err != nil { + return "", err + } + defer w.Close() + + // Report + fmt.Fprintln(p, "Downloading", model, "to", out) + + // Progressively download the model + data := make([]byte, bufSize) + count, pct := int64(0), int64(0) + ticker := time.NewTicker(5 * time.Second) + for { + select { + case <-ctx.Done(): + // Cancelled, return error + return path, ctx.Err() + case <-ticker.C: + pct = DownloadReport(p, pct, count, resp.ContentLength) + default: + // Read body + n, err := resp.Body.Read(data) + if err != nil { + DownloadReport(p, pct, count, resp.ContentLength) + return path, err + } else if m, err := w.Write(data[:n]); err != nil { + return path, err + } else { + count += int64(m) + } + } + } +} + +// Report periodically reports the download progress when percentage changes +func DownloadReport(w io.Writer, pct, count, total int64) int64 { + pct_ := count * 100 / total + if pct_ > pct { + fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_) + } + return pct_ +} diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go new file mode 100755 index 00000000..a5353d1c --- /dev/null +++ b/bindings/go/examples/go-whisper/flags.go @@ -0,0 +1,61 @@ +package main + +import ( + "flag" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Flags struct { + *flag.FlagSet +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewFlags(name string, args []string) (*Flags, error) { + flags := &Flags{ + FlagSet: flag.NewFlagSet(name, flag.ContinueOnError), + } + + // Register the command line arguments + registerFlags(flags) + + // Parse command line + if err := flags.Parse(args); err != nil { + return nil, err + } + + // Return success + return flags, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (flags *Flags) GetModel() string { + return flags.Lookup("model").Value.String() +} + +func (flags *Flags) GetLanguage() string { + return flags.Lookup("language").Value.String() +} + +func (flags *Flags) IsSpeedup() bool { + return flags.Lookup("speedup").Value.String() == "true" +} + +func (flags *Flags) IsTokens() bool { + return flags.Lookup("tokens").Value.String() == "true" +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func registerFlags(flag *Flags) { + flag.String("model", "", "Path to the model file") + flag.String("language", "", "Language") + flag.Bool("speedup", false, "Enable speedup") + flag.Bool("tokens", false, "Display tokens") +} diff --git a/bindings/go/examples/go-whisper/main.go b/bindings/go/examples/go-whisper/main.go new file mode 100755 index 00000000..b3a89db7 --- /dev/null +++ b/bindings/go/examples/go-whisper/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + + // Packages + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func main() { + flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:]) + if err == flag.ErrHelp { + os.Exit(0) + } else if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } else if flags.GetModel() == "" { + fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use") + os.Exit(1) + } else if flags.NArg() == 0 { + fmt.Fprintln(os.Stderr, "No input files specified") + os.Exit(1) + } + + // Load model + model, err := whisper.New(flags.GetModel()) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + defer model.Close() + + // Process files + for _, filename := range flags.Args() { + fmt.Println("Processing", filename) + if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil { + fmt.Fprintln(os.Stderr, err) + continue + } + } +} diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go new file mode 100755 index 00000000..a0e2be86 --- /dev/null +++ b/bindings/go/examples/go-whisper/process.go @@ -0,0 +1,80 @@ +package main + +import ( + "fmt" + "io" + "os" + "time" + + // Package imports + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + wav "github.com/go-audio/wav" +) + +func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error { + var data []float32 + + // Create processing context + context, err := model.NewContext() + if err != nil { + return err + } + + // Open the file + fh, err := os.Open(path) + if err != nil { + return err + } + defer fh.Close() + + // Decode the WAV file + dec := wav.NewDecoder(fh) + if buf, err := dec.FullPCMBuffer(); err != nil { + return err + } else if dec.SampleRate != whisper.SampleRate { + return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate) + } else if dec.NumChans != 1 { + return fmt.Errorf("unsupported number of channels: %d", dec.NumChans) + } else { + data = buf.AsFloat32Buffer().Data + } + + // Set the parameters + var cb whisper.SegmentCallback + if lang != "" { + if err := context.SetLanguage(lang); err != nil { + return err + } + } + if speedup { + context.SetSpeedup(true) + } + if tokens { + cb = func(segment whisper.Segment) { + fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) + for _, token := range segment.Tokens { + fmt.Printf("%q ", token.Text) + } + fmt.Println("") + } + } + + // Process the data + if err := context.Process(data, cb); err != nil { + return err + } + + // Print out the results + for { + segment, err := context.NextSegment() + if err == io.EOF { + break + } else if err != nil { + return err + } + fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text) + } + + // Return success + return nil +} diff --git a/bindings/go/go.mod b/bindings/go/go.mod new file mode 100755 index 00000000..594f184b --- /dev/null +++ b/bindings/go/go.mod @@ -0,0 +1,16 @@ +module github.com/ggerganov/whisper.cpp/bindings/go + +go 1.19 + +require ( + github.com/go-audio/wav v1.1.0 + github.com/stretchr/testify v1.8.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-audio/audio v1.0.0 // indirect + github.com/go-audio/riff v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/bindings/go/params.go b/bindings/go/params.go new file mode 100644 index 00000000..7f4c509c --- /dev/null +++ b/bindings/go/params.go @@ -0,0 +1,134 @@ +package whisper + +// This file defines the whisper_token, whisper_token_data and whisper_full_params +// structures, which are used by the whisper_full() function. + +import ( + "fmt" +) + +/////////////////////////////////////////////////////////////////////////////// +// CGO + +/* +#include +*/ +import "C" + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (p *Params) SetTranslate(v bool) { + p.translate = toBool(v) +} + +func (p *Params) SetNoContext(v bool) { + p.no_context = toBool(v) +} + +func (p *Params) SetSingleSegment(v bool) { + p.single_segment = toBool(v) +} + +func (p *Params) SetPrintSpecial(v bool) { + p.print_special = toBool(v) +} + +func (p *Params) SetPrintProgress(v bool) { + p.print_progress = toBool(v) +} + +func (p *Params) SetPrintRealtime(v bool) { + p.print_realtime = toBool(v) +} + +func (p *Params) SetPrintTimestamps(v bool) { + p.print_timestamps = toBool(v) +} + +func (p *Params) SetSpeedup(v bool) { + p.speed_up = toBool(v) +} + +func (p *Params) SetLanguage(lang int) error { + str := C.whisper_lang_str(C.int(lang)) + if str == nil { + return ErrInvalidLanguage + } else { + p.language = str + } + return nil +} + +func (p *Params) Language() int { + if p.language == nil { + return -1 + } + return int(C.whisper_lang_id(p.language)) +} + +func (p *Params) SetThreads(threads int) { + p.n_threads = C.int(threads) +} + +func (p *Params) SetOffset(offset_ms int) { + p.offset_ms = C.int(offset_ms) +} + +func (p *Params) SetDuration(duration_ms int) { + p.duration_ms = C.int(duration_ms) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func toBool(v bool) C.bool { + if v { + return C.bool(true) + } + return C.bool(false) +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (p *Params) String() string { + str := "" +} diff --git a/bindings/go/pkg/whisper/consts.go b/bindings/go/pkg/whisper/consts.go new file mode 100755 index 00000000..710073f0 --- /dev/null +++ b/bindings/go/pkg/whisper/consts.go @@ -0,0 +1,27 @@ +package whisper + +import ( + "errors" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +/////////////////////////////////////////////////////////////////////////////// +// ERRORS + +var ( + ErrUnableToLoadModel = errors.New("unable to load model") + ErrInternalAppError = errors.New("internal application error") + ErrProcessingFailed = errors.New("processing failed") + ErrUnsupportedLanguage = errors.New("unsupported language") +) + +/////////////////////////////////////////////////////////////////////////////// +// CONSTANTS + +// SampleRate is the sample rate of the audio data. +const SampleRate = whisper.SampleRate + +// SampleBits is the number of bytes per sample. +const SampleBits = whisper.SampleBits diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go new file mode 100755 index 00000000..baff611c --- /dev/null +++ b/bindings/go/pkg/whisper/context.go @@ -0,0 +1,145 @@ +package whisper + +import ( + "io" + "strings" + "time" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type context struct { + n int + model *model + params whisper.Params +} + +// Make sure context adheres to the interface +var _ Context = (*context)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewContext(model *model, params whisper.Params) (Context, error) { + context := new(context) + context.model = model + context.params = params + + // Return success + return context, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Set the language to use for speech recognition. +func (context *context) SetLanguage(lang string) error { + if context.model.ctx == nil { + return ErrInternalAppError + } + if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { + return ErrUnsupportedLanguage + } else if err := context.params.SetLanguage(id); err != nil { + return err + } + // Return success + return nil +} + +// Get language +func (context *context) Language() string { + return whisper.Whisper_lang_str(context.params.Language()) +} + +// Set speedup flag +func (context *context) SetSpeedup(v bool) { + context.params.SetSpeedup(v) +} + +// Process new sample data and return any errors +func (context *context) Process(data []float32, cb SegmentCallback) error { + if context.model.ctx == nil { + return ErrInternalAppError + } + // If the callback is defined then we force on single_segment mode + if cb != nil { + context.params.SetSingleSegment(true) + } + + // We don't do parallel processing at the moment + processors := 0 + if processors > 1 { + if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if cb != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + cb(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err + } + } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { + if cb != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + cb(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err + } + + // Return success + return nil +} + +// Return the next segment of tokens +func (context *context) NextSegment() (Segment, error) { + if context.model.ctx == nil { + return Segment{}, ErrInternalAppError + } + if context.n >= context.model.ctx.Whisper_full_n_segments() { + return Segment{}, io.EOF + } + + // Populate result + result := toSegment(context.model.ctx, context.n) + + // Increment the cursor + context.n++ + + // Return success + return result, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func toSegment(ctx *whisper.Context, n int) Segment { + return Segment{ + Num: n, + Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), + Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, + End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, + Tokens: toTokens(ctx, n), + } +} + +func toTokens(ctx *whisper.Context, n int) []Token { + result := make([]Token, ctx.Whisper_full_n_tokens(n)) + for i := 0; i < len(result); i++ { + result[i] = Token{ + Id: int(ctx.Whisper_full_get_token_id(n, i)), + Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)), + P: ctx.Whisper_full_get_token_p(n, i), + } + } + return result +} diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go new file mode 100755 index 00000000..c8c6016e --- /dev/null +++ b/bindings/go/pkg/whisper/context_test.go @@ -0,0 +1,55 @@ +package whisper_test + +import ( + "os" + "testing" + + // Packages + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + assert "github.com/stretchr/testify/assert" +) + +const ( + ModelPath = "../../models/ggml-tiny.bin" + SamplePath = "../../samples/jfk.wav" +) + +func Test_Whisper_000(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Load model + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + assert.NoError(model.Close()) + + t.Log("languages=", model.Languages()) +} + +func Test_Whisper_001(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Load model + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + // Get context for decoding + ctx, err := model.NewContext() + assert.NoError(err) + assert.NotNil(ctx) + +} diff --git a/bindings/go/pkg/whisper/doc.go b/bindings/go/pkg/whisper/doc.go new file mode 100755 index 00000000..fd4f1b97 --- /dev/null +++ b/bindings/go/pkg/whisper/doc.go @@ -0,0 +1,4 @@ +/* +This is the higher-level speech-to-text whisper.cpp API for go +*/ +package whisper diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go new file mode 100755 index 00000000..53e4f3f0 --- /dev/null +++ b/bindings/go/pkg/whisper/interface.go @@ -0,0 +1,63 @@ +package whisper + +import ( + "io" + "time" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// SegmentCallback is the callback function for processing segments in real +// time. It is called during the Process function +type SegmentCallback func(Segment) + +// Model is the interface to a whisper model. Create a new model with the +// function whisper.New(string) +type Model interface { + io.Closer + + // Return a new speech-to-text context. + NewContext() (Context, error) + + // Return all languages supported. + Languages() []string +} + +// Context is the speach recognition context. +type Context interface { + SetLanguage(string) error // Set the language to use for speech recognition. + Language() string // Get language + SetSpeedup(bool) // Set speedup flag + + // Process mono audio data and return any errors. + // If defined, newly generated segments are passed to the + // callback function during processing. + Process([]float32, SegmentCallback) error + + // After process is called, return segments until the end of the stream + // is reached, when io.EOF is returned. + NextSegment() (Segment, error) +} + +// Segment is the text result of a speech recognition. +type Segment struct { + // Segment Number + Num int + + // Time beginning and end timestamps for the segment. + Start, End time.Duration + + // The text of the segment. + Text string + + // The tokens of the segment. + Tokens []Token +} + +// Token is a text or special token +type Token struct { + Id int + Text string + P float32 +} diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go new file mode 100755 index 00000000..13cb52ca --- /dev/null +++ b/bindings/go/pkg/whisper/model.go @@ -0,0 +1,95 @@ +package whisper + +import ( + "fmt" + "os" + "runtime" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + path string + ctx *whisper.Context +} + +// Make sure model adheres to the interface +var _ Model = (*model)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func New(path string) (*model, error) { + model := new(model) + if _, err := os.Stat(path); err != nil { + return nil, err + } else if ctx := whisper.Whisper_init(path); ctx == nil { + return nil, ErrUnableToLoadModel + } else { + model.ctx = ctx + model.path = path + } + + // Return success + return model, nil +} + +func (model *model) Close() error { + if model.ctx != nil { + model.ctx.Whisper_free() + } + + // Release resources + model.ctx = nil + + // Return success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (model *model) String() string { + str := "" +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return all recognized languages. Initially it is set to auto-detect +func (model *model) Languages() []string { + result := make([]string, 0, whisper.Whisper_lang_max_id()) + for i := 0; i < whisper.Whisper_lang_max_id(); i++ { + str := whisper.Whisper_lang_str(i) + if model.ctx.Whisper_lang_id(str) >= 0 { + result = append(result, str) + } + } + return result +} + +func (model *model) NewContext() (Context, error) { + if model.ctx == nil { + return nil, ErrInternalAppError + } + + // Create new context + params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) + params.SetTranslate(false) + params.SetPrintSpecial(false) + params.SetPrintProgress(false) + params.SetPrintRealtime(false) + params.SetPrintTimestamps(false) + params.SetThreads(runtime.NumCPU()) + + // Return new context + return NewContext(model, params) +} diff --git a/bindings/go/samples/jfk.wav b/bindings/go/samples/jfk.wav new file mode 100755 index 00000000..3184d372 Binary files /dev/null and b/bindings/go/samples/jfk.wav differ diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go new file mode 100644 index 00000000..2584f7bb --- /dev/null +++ b/bindings/go/whisper.go @@ -0,0 +1,412 @@ +package whisper + +import ( + "errors" + "unsafe" +) + +/////////////////////////////////////////////////////////////////////////////// +// CGO + +/* +#cgo CFLAGS: -I${SRCDIR}/../.. +#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++ +#cgo darwin LDFLAGS: -framework Accelerate +#include +#include + +extern void callNewSegment(void* user_data, int new); +extern bool callEncoderBegin(void* user_data); + +// Text segment callback +// Called on every newly generated text segment +// Use the whisper_full_...() functions to obtain the text segments +static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) { + if(user_data != NULL && ctx != NULL) { + callNewSegment(user_data, n_new); + } +} + +// Encoder begin callback +// If not NULL, called before the encoder starts +// If it returns false, the computation is aborted +static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) { + if(user_data != NULL && ctx != NULL) { + return callEncoderBegin(user_data); + } + return false; +} + +// Get default parameters and set callbacks +static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) { + struct whisper_full_params params = whisper_full_default_params(strategy); + params.new_segment_callback = whisper_new_segment_cb; + params.new_segment_callback_user_data = (void*)(ctx); + params.encoder_begin_callback = whisper_encoder_begin_cb; + params.encoder_begin_callback_user_data = (void*)(ctx); + return params; +} +*/ +import "C" + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ( + Context C.struct_whisper_context + Token C.whisper_token + TokenData C.struct_whisper_token_data + SamplingStrategy C.enum_whisper_sampling_strategy + Params C.struct_whisper_full_params +) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY + SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH +) + +const ( + SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second + SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits + NumFFT = C.WHISPER_N_FFT + NumMEL = C.WHISPER_N_MEL + HopLength = C.WHISPER_HOP_LENGTH + ChunkSize = C.WHISPER_CHUNK_SIZE +) + +var ( + ErrTokenizerFailed = errors.New("whisper_tokenize failed") + ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed") + ErrConversionFailed = errors.New("whisper_convert failed") + ErrInvalidLanguage = errors.New("invalid language") +) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Allocates all memory needed for the model and loads the model from the given file. +// Returns NULL on failure. +func Whisper_init(path string) *Context { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + if ctx := C.whisper_init(cPath); ctx != nil { + return (*Context)(ctx) + } else { + return nil + } +} + +// Frees all memory allocated by the model. +func (ctx *Context) Whisper_free() { + C.whisper_free((*C.struct_whisper_context)(ctx)) +} + +// Convert RAW PCM audio to log mel spectrogram. +// The resulting spectrogram is stored inside the provided whisper context. +func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { + if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// This can be used to set a custom log mel spectrogram inside the provided whisper context. +// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. +// n_mel must be 80 +func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { + if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// offset can be used to specify the offset of the first frame in the spectrogram. +func (ctx *Context) Whisper_encode(offset, threads int) error { + if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// Run the Whisper decoder to obtain the logits and probabilities for the next token. +// Make sure to call whisper_encode() first. +// tokens + n_tokens is the provided context for the decoder. +// n_past is the number of tokens to use from previous decoder calls. +func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error { + if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// whisper_sample_best() returns the token with the highest probability +func (ctx *Context) Whisper_sample_best() TokenData { + return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx))) +} + +// whisper_sample_timestamp() returns the most probable timestamp token +func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData { + return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial))) +} + +// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens. +// Returns the number of tokens on success +func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) { + cText := C.CString(text) + defer C.free(unsafe.Pointer(cText)) + if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 { + return int(n), nil + } else { + return 0, ErrTokenizerFailed + } +} + +// Return the id of the specified language, returns -1 if not found +func (ctx *Context) Whisper_lang_id(lang string) int { + return int(C.whisper_lang_id(C.CString(lang))) +} + +// Largest language id (i.e. number of available languages - 1) +func Whisper_lang_max_id() int { + return int(C.whisper_lang_max_id()) +} + +// Return the short string of the specified language id (e.g. 2 -> "de"), +// returns empty string if not found +func Whisper_lang_str(id int) string { + return C.GoString(C.whisper_lang_str(C.int(id))) +} + +// Use mel data at offset_ms to try and auto-detect the spoken language +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// Returns the probabilities of all languages. +// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 +func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) { + probs := make([]float32, Whisper_lang_max_id()+1) + if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { + return nil, ErrAutoDetectFailed + } else { + return probs, nil + } +} + +func (ctx *Context) Whisper_n_len() int { + return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_n_vocab() int { + return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_n_text_ctx() int { + return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_is_multilingual() int { + return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx))) +} + +// The probabilities for the next token +//func (ctx *Whisper_context) Whisper_get_probs() []float32 { +// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()] +//} + +// Token Id -> String. Uses the vocabulary in the provided context +func (ctx *Context) Whisper_token_to_str(token Token) string { + return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token))) +} + +// Special tokens +func (ctx *Context) Whisper_token_eot() Token { + return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_sot() Token { + return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_prev() Token { + return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_solm() Token { + return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_not() Token { + return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_beg() Token { + return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx))) +} + +// Special tokens +func (ctx *Context) Whisper_token_lang(lang_id int) Token { + return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id))) +} + +// Task tokens +func Whisper_token_translate() Token { + return Token(C.whisper_token_translate()) +} + +// Task tokens +func Whisper_token_transcribe() Token { + return Token(C.whisper_token_transcribe()) +} + +// Performance information +func (ctx *Context) Whisper_print_timings() { + C.whisper_print_timings((*C.struct_whisper_context)(ctx)) +} + +// Performance information +func (ctx *Context) Whisper_reset_timings() { + C.whisper_reset_timings((*C.struct_whisper_context)(ctx)) +} + +// Print system information +func Whisper_print_system_info() string { + return C.GoString(C.whisper_print_system_info()) +} + +// Return default parameters for a strategy +func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params { + // Get default parameters + return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy))) +} + +// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text +// Uses the specified decoding strategy to obtain the text. +func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { + registerEncoderBeginCallback(ctx, encoderBeginCallback) + registerNewSegmentCallback(ctx, newSegmentCallback) + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// Split the input audio in chunks and process each chunk separately using whisper_full() +// It seems this approach can offer some speedup in some cases. +// However, the transcription accuracy can be worse at the beginning and end of each chunk. +func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { + registerEncoderBeginCallback(ctx, encoderBeginCallback) + registerNewSegmentCallback(ctx, newSegmentCallback) + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + + if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + +// Number of generated text segments. +// A segment can be a few words, a sentence, or even a paragraph. +func (ctx *Context) Whisper_full_n_segments() int { + return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx))) +} + +// Get the start and end time of the specified segment. +func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 { + return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Get the start and end time of the specified segment. +func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 { + return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Get the text of the specified segment. +func (ctx *Context) Whisper_full_get_segment_text(segment int) string { + return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Get number of tokens in the specified segment. +func (ctx *Context) Whisper_full_n_tokens(segment int) int { + return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Get the token text of the specified token index in the specified segment. +func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string { + return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +} + +// Get the token of the specified token index in the specified segment. +func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token { + return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +} + +// Get token data for the specified token in the specified segment. +// This contains probabilities, timestamps, etc. +func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData { + return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +} + +// Get the probability of the specified token in the specified segment. +func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 { + return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) +} + +/////////////////////////////////////////////////////////////////////////////// +// CALLBACKS + +var ( + cbNewSegment = make(map[unsafe.Pointer]func(int)) + cbEncoderBegin = make(map[unsafe.Pointer]func() bool) +) + +func registerNewSegmentCallback(ctx *Context, fn func(int)) { + if fn == nil { + delete(cbNewSegment, unsafe.Pointer(ctx)) + } else { + cbNewSegment[unsafe.Pointer(ctx)] = fn + } +} + +func registerEncoderBeginCallback(ctx *Context, fn func() bool) { + if fn == nil { + delete(cbEncoderBegin, unsafe.Pointer(ctx)) + } else { + cbEncoderBegin[unsafe.Pointer(ctx)] = fn + } +} + +//export callNewSegment +func callNewSegment(user_data unsafe.Pointer, new C.int) { + if fn, ok := cbNewSegment[user_data]; ok { + fn(int(new)) + } +} + +//export callEncoderBegin +func callEncoderBegin(user_data unsafe.Pointer) C.bool { + if fn, ok := cbEncoderBegin[user_data]; ok { + if fn() { + return C.bool(true) + } else { + return C.bool(false) + } + } + return true +} diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go new file mode 100644 index 00000000..d7b8caef --- /dev/null +++ b/bindings/go/whisper_test.go @@ -0,0 +1,110 @@ +package whisper_test + +import ( + "os" + "runtime" + "testing" + "time" + + // Packages + whisper "github.com/ggerganov/whisper.cpp/bindings/go" + wav "github.com/go-audio/wav" + assert "github.com/stretchr/testify/assert" +) + +const ( + ModelPath = "models/ggml-small.en.bin" + SamplePath = "samples/jfk.wav" +) + +func Test_Whisper_000(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + ctx := whisper.Whisper_init(ModelPath) + assert.NotNil(ctx) + ctx.Whisper_free() +} + +func Test_Whisper_001(t *testing.T) { + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Open samples + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Read samples + d := wav.NewDecoder(fh) + buf, err := d.FullPCMBuffer() + assert.NoError(err) + + // Run whisper + ctx := whisper.Whisper_init(ModelPath) + assert.NotNil(ctx) + defer ctx.Whisper_free() + assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil)) + + // Print out tokens + num_segments := ctx.Whisper_full_n_segments() + assert.GreaterOrEqual(num_segments, 1) + for i := 0; i < num_segments; i++ { + str := ctx.Whisper_full_get_segment_text(i) + assert.NotEmpty(str) + t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond + t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond + t.Logf("[%6s->%-6s] %q", t0, t1, str) + } +} + +func Test_Whisper_002(t *testing.T) { + assert := assert.New(t) + for i := 0; i < whisper.Whisper_lang_max_id(); i++ { + str := whisper.Whisper_lang_str(i) + assert.NotEmpty(str) + t.Log(str) + } +} + +func Test_Whisper_003(t *testing.T) { + threads := runtime.NumCPU() + assert := assert.New(t) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + // Open samples + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Read samples + d := wav.NewDecoder(fh) + buf, err := d.FullPCMBuffer() + assert.NoError(err) + + // Make the model + ctx := whisper.Whisper_init(ModelPath) + assert.NotNil(ctx) + defer ctx.Whisper_free() + + // Get MEL + assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads)) + + // Get Languages + languages, err := ctx.Whisper_lang_auto_detect(0, threads) + assert.NoError(err) + for i, p := range languages { + t.Logf("%s: %f", whisper.Whisper_lang_str(i), p) + } +}