mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-06 10:59:11 +00:00
go : improve progress reporting and callback handling (#1024)
- Rename `cb` to `callNewSegment` in the `Process` function - Add `callProgress` as a new parameter to the `Process` function - Introduce `ProgressCallback` type for reporting progress during processing - Update `Whisper_full` function to include `progressCallback` parameter - Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks Signed-off-by: appleboy <appleboy.tw@gmail.com>
This commit is contained in:
parent
6a7f3b8db2
commit
7dfc11843c
@ -152,7 +152,11 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process new sample data and return any errors
|
// Process new sample data and return any errors
|
||||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
func (context *context) Process(
|
||||||
|
data []float32,
|
||||||
|
callNewSegment SegmentCallback,
|
||||||
|
callProgress ProgressCallback,
|
||||||
|
) error {
|
||||||
if context.model.ctx == nil {
|
if context.model.ctx == nil {
|
||||||
return ErrInternalAppError
|
return ErrInternalAppError
|
||||||
}
|
}
|
||||||
@ -165,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
|
|||||||
processors := 0
|
processors := 0
|
||||||
if processors > 1 {
|
if processors > 1 {
|
||||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
||||||
if cb != nil {
|
if callNewSegment != nil {
|
||||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
s0 := num_segments - new
|
s0 := num_segments - new
|
||||||
for i := s0; i < num_segments; i++ {
|
for i := s0; i < num_segments; i++ {
|
||||||
cb(toSegment(context.model.ctx, i))
|
callNewSegment(toSegment(context.model.ctx, i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
||||||
if cb != nil {
|
if callNewSegment != nil {
|
||||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
s0 := num_segments - new
|
s0 := num_segments - new
|
||||||
for i := s0; i < num_segments; i++ {
|
for i := s0; i < num_segments; i++ {
|
||||||
cb(toSegment(context.model.ctx, i))
|
callNewSegment(toSegment(context.model.ctx, i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}, func(progress int) {
|
||||||
|
if callProgress != nil {
|
||||||
|
callProgress(progress)
|
||||||
|
}
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,10 @@ import (
|
|||||||
// time. It is called during the Process function
|
// time. It is called during the Process function
|
||||||
type SegmentCallback func(Segment)
|
type SegmentCallback func(Segment)
|
||||||
|
|
||||||
|
// ProgressCallback is the callback function for reporting progress during
|
||||||
|
// processing. It is called during the Process function
|
||||||
|
type ProgressCallback func(int)
|
||||||
|
|
||||||
// Model is the interface to a whisper model. Create a new model with the
|
// Model is the interface to a whisper model. Create a new model with the
|
||||||
// function whisper.New(string)
|
// function whisper.New(string)
|
||||||
type Model interface {
|
type Model interface {
|
||||||
@ -47,7 +51,7 @@ type Context interface {
|
|||||||
// Process mono audio data and return any errors.
|
// Process mono audio data and return any errors.
|
||||||
// If defined, newly generated segments are passed to the
|
// If defined, newly generated segments are passed to the
|
||||||
// callback function during processing.
|
// callback function during processing.
|
||||||
Process([]float32, SegmentCallback) error
|
Process([]float32, SegmentCallback, ProgressCallback) error
|
||||||
|
|
||||||
// After process is called, return segments until the end of the stream
|
// After process is called, return segments until the end of the stream
|
||||||
// is reached, when io.EOF is returned.
|
// is reached, when io.EOF is returned.
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
extern void callNewSegment(void* user_data, int new);
|
extern void callNewSegment(void* user_data, int new);
|
||||||
|
extern void callProgress(void* user_data, int progress);
|
||||||
extern bool callEncoderBegin(void* user_data);
|
extern bool callEncoderBegin(void* user_data);
|
||||||
|
|
||||||
// Text segment callback
|
// Text segment callback
|
||||||
@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Progress callback
|
||||||
|
// Called on every newly generated text segment
|
||||||
|
// Use the whisper_full_...() functions to obtain the text segments
|
||||||
|
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
|
||||||
|
if(user_data != NULL && ctx != NULL) {
|
||||||
|
callProgress(user_data, progress);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Encoder begin callback
|
// Encoder begin callback
|
||||||
// If not NULL, called before the encoder starts
|
// If not NULL, called before the encoder starts
|
||||||
// If it returns false, the computation is aborted
|
// If it returns false, the computation is aborted
|
||||||
@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
|
|||||||
params.new_segment_callback_user_data = (void*)(ctx);
|
params.new_segment_callback_user_data = (void*)(ctx);
|
||||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||||
|
params.progress_callback = whisper_progress_cb;
|
||||||
|
params.progress_callback_user_data = (void*)(ctx);
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
|
|||||||
|
|
||||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
// Uses the specified decoding strategy to obtain the text.
|
// Uses the specified decoding strategy to obtain the text.
|
||||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
func (ctx *Context) Whisper_full(
|
||||||
|
params Params,
|
||||||
|
samples []float32,
|
||||||
|
encoderBeginCallback func() bool,
|
||||||
|
newSegmentCallback func(int),
|
||||||
|
progressCallback func(int),
|
||||||
|
) error {
|
||||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||||
|
registerProgressCallback(ctx, progressCallback)
|
||||||
defer registerEncoderBeginCallback(ctx, nil)
|
defer registerEncoderBeginCallback(ctx, nil)
|
||||||
defer registerNewSegmentCallback(ctx, nil)
|
defer registerNewSegmentCallback(ctx, nil)
|
||||||
|
defer registerProgressCallback(ctx, nil)
|
||||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||||
|
cbProgress = make(map[unsafe.Pointer]func(int))
|
||||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerProgressCallback(ctx *Context, fn func(int)) {
|
||||||
|
if fn == nil {
|
||||||
|
delete(cbProgress, unsafe.Pointer(ctx))
|
||||||
|
} else {
|
||||||
|
cbProgress[unsafe.Pointer(ctx)] = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||||
if fn == nil {
|
if fn == nil {
|
||||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||||
@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//export callProgress
|
||||||
|
func callProgress(user_data unsafe.Pointer, progress C.int) {
|
||||||
|
if fn, ok := cbProgress[user_data]; ok {
|
||||||
|
fn(int(progress))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//export callEncoderBegin
|
//export callEncoderBegin
|
||||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||||
|
@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
|
|||||||
defer ctx.Whisper_free()
|
defer ctx.Whisper_free()
|
||||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||||
data := buf.AsFloat32Buffer().Data
|
data := buf.AsFloat32Buffer().Data
|
||||||
err = ctx.Whisper_full(params, data, nil, nil)
|
err = ctx.Whisper_full(params, data, nil, nil, nil)
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
|
|
||||||
// Print out tokens
|
// Print out tokens
|
||||||
|
Loading…
x
Reference in New Issue
Block a user