whisper.cpp/bindings/ruby/ext/ruby_whisper_params.c
KITAITI Makoto 7ffcd05267
ruby : Make context accept initial parameters, API to retrieve a segment and more (#2749)
* Fix type signature for Whisper.log_set

* Use cache file for model when offline

* Extract ruby_whisper_transcribe() into a file

* Extract Whisper::Error

* Use FileList for ext/*.{c,cpp,h}

* Extract Whisper::Segment

* Extract Whisper::Model

* Extract Whisper::Params

* Extract Whisper::Context

* Extract log_callback function

* Write base code in C rather than C++

* Use chdir instead of Dir.chdir in Rakefile

* Define alloc func for Whisper::Model

* Define Whisper::Params' calback and user data reader

* Add test for Whisper::Params.new with keyword arguments

* Make Whisper::Params.new accept keyword arguments

* Update type signatures

* Update README

* Update CLEAN targets

* Fix document comment for Whisper::Params#new_segment_callback=

* Use macro to define params

* Fix dependency of build task

* Set Whisper.finalize_log_callback visibility to private

* Make Whisper::Context#full and full_parallel return self

* Add test for Whisper::Context#full_get_segment

* Add Whisper::Context#full_get_segment

* Update signatures

* Update README

* Fix signature

* Resplace #initialize with .new in signature file [skip ci]

* Fix potential overflow
2025-01-21 09:39:54 +02:00

1078 lines
28 KiB
C

#include <ruby.h>
#include "ruby_whisper.h"
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \
} else { \
rwp->params.prop = true; \
} \
return value; \
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \
return Qtrue; \
} else { \
return Qfalse; \
}
#define DEFINE_PARAM(param_name, nth) \
id_ ## param_name = rb_intern(#param_name); \
param_names[nth] = id_ ## param_name; \
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
extern VALUE cParams;
extern ID id_call;
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
static ID id_language;
static ID id_translate;
static ID id_no_context;
static ID id_single_segment;
static ID id_print_special;
static ID id_print_progress;
static ID id_print_realtime;
static ID id_print_timestamps;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
static ID id_split_on_word;
static ID id_initial_prompt;
static ID id_diarize;
static ID id_offset;
static ID id_duration;
static ID id_max_text_tokens;
static ID id_temperature;
static ID id_max_initial_ts;
static ID id_length_penalty;
static ID id_temperature_inc;
static ID id_entropy_thold;
static ID id_logprob_thold;
static ID id_no_speech_thold;
static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data;
static ID id_progress_callback;
static ID id_progress_callback_user_data;
static ID id_abort_callback;
static ID id_abort_callback_user_data;
static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
{
rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_callback_container*
rb_whisper_callback_container_allocate() {
ruby_whisper_callback_container *container;
container = ALLOC(ruby_whisper_callback_container);
container->context = NULL;
container->user_data = Qnil;
container->callback = Qnil;
container->callbacks = rb_ary_new();
return container;
}
static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
}
static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
}
static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
}
void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->progress_callback_container->context = context;
rwp->params.progress_callback = progress_callback;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
}
void
rb_whisper_params_mark(ruby_whisper_params *rwp)
{
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
}
void
ruby_whisper_params_free(ruby_whisper_params *rwp)
{
}
void
rb_whisper_params_free(ruby_whisper_params *rwp)
{
// How to free user_data and callback only when not referred to by others?
ruby_whisper_params_free(rwp);
free(rwp);
}
static VALUE
ruby_whisper_params_allocate(VALUE klass)
{
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
/*
* params.language = "auto" | "en", etc...
*
* call-seq:
* language = lang_name -> lang_name
*/
static VALUE
ruby_whisper_params_set_language(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto";
} else {
rwp->params.language = StringValueCStr(value);
}
return value;
}
/*
* call-seq:
* language -> String
*/
static VALUE
ruby_whisper_params_get_language(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) {
return rb_str_new2(rwp->params.language);
} else {
return rb_str_new2("auto");
}
}
/*
* call-seq:
* translate = do_translate -> do_translate
*/
static VALUE
ruby_whisper_params_set_translate(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, translate, value)
}
/*
* call-seq:
* translate -> bool
*/
static VALUE
ruby_whisper_params_get_translate(VALUE self)
{
BOOL_PARAMS_GETTER(self, translate)
}
/*
* call-seq:
* no_context = dont_use_context -> dont_use_context
*/
static VALUE
ruby_whisper_params_set_no_context(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, no_context, value)
}
/*
* If true, does not use past transcription (if any) as initial prompt for the decoder.
*
* call-seq:
* no_context -> bool
*/
static VALUE
ruby_whisper_params_get_no_context(VALUE self)
{
BOOL_PARAMS_GETTER(self, no_context)
}
/*
* call-seq:
* single_segment = force_single -> force_single
*/
static VALUE
ruby_whisper_params_set_single_segment(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, single_segment, value)
}
/*
* If true, forces single segment output (useful for streaming).
*
* call-seq:
* single_segment -> bool
*/
static VALUE
ruby_whisper_params_get_single_segment(VALUE self)
{
BOOL_PARAMS_GETTER(self, single_segment)
}
/*
* call-seq:
* print_special = force_print -> force_print
*/
static VALUE
ruby_whisper_params_set_print_special(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, print_special, value)
}
/*
* If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
*
* call-seq:
* print_special -> bool
*/
static VALUE
ruby_whisper_params_get_print_special(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_special)
}
/*
* call-seq:
* print_progress = force_print -> force_print
*/
static VALUE
ruby_whisper_params_set_print_progress(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, print_progress, value)
}
/*
* If true, prints progress information.
*
* call-seq:
* print_progress -> bool
*/
static VALUE
ruby_whisper_params_get_print_progress(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_progress)
}
/*
* call-seq:
* print_realtime = force_print -> force_print
*/
static VALUE
ruby_whisper_params_set_print_realtime(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, print_realtime, value)
}
/*
* If true, prints results from within whisper.cpp. (avoid it, use callback instead)
* call-seq:
* print_realtime -> bool
*/
static VALUE
ruby_whisper_params_get_print_realtime(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_realtime)
}
/*
* call-seq:
* print_timestamps = force_print -> force_print
*/
static VALUE
ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, print_timestamps, value)
}
/*
* If true, prints timestamps for each text segment when printing realtime.
*
* call-seq:
* print_timestamps -> bool
*/
static VALUE
ruby_whisper_params_get_print_timestamps(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_timestamps)
}
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
*/
static VALUE
ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, suppress_blank, value)
}
/*
* If true, suppresses blank outputs.
*
* call-seq:
* suppress_blank -> bool
*/
static VALUE
ruby_whisper_params_get_suppress_blank(VALUE self)
{
BOOL_PARAMS_GETTER(self, suppress_blank)
}
/*
* call-seq:
* suppress_nst = force_suppress -> force_suppress
*/
static VALUE
ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, suppress_nst, value)
}
/*
* If true, suppresses non-speech-tokens.
*
* call-seq:
* suppress_nst -> bool
*/
static VALUE
ruby_whisper_params_get_suppress_nst(VALUE self)
{
BOOL_PARAMS_GETTER(self, suppress_nst)
}
/*
* If true, enables token-level timestamps.
*
* call-seq:
* token_timestamps -> bool
*/
static VALUE
ruby_whisper_params_get_token_timestamps(VALUE self)
{
BOOL_PARAMS_GETTER(self, token_timestamps)
}
/*
* call-seq:
* token_timestamps = force_timestamps -> force_timestamps
*/
static VALUE
ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, token_timestamps, value)
}
/*
* If true, split on word rather than on token (when used with max_len).
*
* call-seq:
* translate -> bool
*/
static VALUE
ruby_whisper_params_get_split_on_word(VALUE self)
{
BOOL_PARAMS_GETTER(self, split_on_word)
}
/*
* call-seq:
* split_on_word = force_split -> force_split
*/
static VALUE
ruby_whisper_params_set_split_on_word(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, split_on_word, value)
}
/*
* Tokens to provide to the whisper decoder as initial prompt
* these are prepended to any existing text context from a previous call
* use whisper_tokenize() to convert text to tokens.
* Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
*
* call-seq:
* initial_prompt -> String
*/
static VALUE
ruby_whisper_params_get_initial_prompt(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
}
/*
* call-seq:
* initial_prompt = prompt -> prompt
*/
static VALUE
ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.initial_prompt = StringValueCStr(value);
return value;
}
/*
* If true, enables diarization.
*
* call-seq:
* diarize -> bool
*/
static VALUE
ruby_whisper_params_get_diarize(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) {
return Qtrue;
} else {
return Qfalse;
}
}
/*
* call-seq:
* diarize = force_diarize -> force_diarize
*/
static VALUE
ruby_whisper_params_set_diarize(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->diarize = false;
} else {
rwp->diarize = true;
} \
return value;
}
/*
* Start offset in ms.
*
* call-seq:
* offset -> Integer
*/
static VALUE
ruby_whisper_params_get_offset(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms);
}
/*
* call-seq:
* offset = offset_ms -> offset_ms
*/
static VALUE
ruby_whisper_params_set_offset(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value);
return value;
}
/*
* Audio duration to process in ms.
*
* call-seq:
* duration -> Integer
*/
static VALUE
ruby_whisper_params_get_duration(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms);
}
/*
* call-seq:
* duration = duration_ms -> duration_ms
*/
static VALUE
ruby_whisper_params_set_duration(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value);
return value;
}
/*
* Max tokens to use from past text as prompt for the decoder.
*
* call-seq:
* max_text_tokens -> Integer
*/
static VALUE
ruby_whisper_params_get_max_text_tokens(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx);
}
/*
* call-seq:
* max_text_tokens = n_tokens -> n_tokens
*/
static VALUE
ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
/*
* call-seq:
* temperature -> Float
*/
static VALUE
ruby_whisper_params_get_temperature(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature);
}
/*
* call-seq:
* temperature = temp -> temp
*/
static VALUE
ruby_whisper_params_set_temperature(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature = RFLOAT_VALUE(value);
return value;
}
/*
* See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
*
* call-seq:
* max_initial_ts -> Flaot
*/
static VALUE
ruby_whisper_params_get_max_initial_ts(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.max_initial_ts);
}
/*
* call-seq:
* max_initial_ts = timestamp -> timestamp
*/
static VALUE
ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.max_initial_ts = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* length_penalty -> Float
*/
static VALUE
ruby_whisper_params_get_length_penalty(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.length_penalty);
}
/*
* call-seq:
* length_penalty = penalty -> penalty
*/
static VALUE
ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.length_penalty = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* temperature_inc -> Float
*/
static VALUE
ruby_whisper_params_get_temperature_inc(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature_inc);
}
/*
* call-seq:
* temperature_inc = inc -> inc
*/
static VALUE
ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature_inc = RFLOAT_VALUE(value);
return value;
}
/*
* Similar to OpenAI's "compression_ratio_threshold"
*
* call-seq:
* entropy_thold -> Float
*/
static VALUE
ruby_whisper_params_get_entropy_thold(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.entropy_thold);
}
/*
* call-seq:
* entropy_thold = threshold -> threshold
*/
static VALUE
ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.entropy_thold = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* logprob_thold -> Float
*/
static VALUE
ruby_whisper_params_get_logprob_thold(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.logprob_thold);
}
/*
* call-seq:
* logprob_thold = threshold -> threshold
*/
static VALUE
ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* no_speech_thold -> Float
*/
static VALUE
ruby_whisper_params_get_no_speech_thold(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold);
}
/*
* call-seq:
* no_speech_thold = threshold -> threshold
*/
static VALUE
ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_params_get_new_segment_callback(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->callback;
}
/*
* Sets new segment callback, called for every newly generated text segment.
*
* params.new_segment_callback = ->(context, _, n_new, user_data) {
* # ...
* }
*
* call-seq:
* new_segment_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of new segment callback.
*
* call-seq:
* new_segment_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->user_data = value;
return value;
}
static VALUE
ruby_whisper_params_get_progress_callback(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->callback;
}
/*
* Sets progress callback, called on each progress update.
*
* params.new_segment_callback = ->(context, _, progress, user_data) {
* # ...
* }
*
* +progress+ is an Integer between 0 and 100.
*
* call-seq:
* progress_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_progress_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of progress callback.
*
* call-seq:
* progress_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->user_data = value;
return value;
}
static VALUE
ruby_whisper_params_get_abort_callback(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->callback;
}
/*
* Sets abort callback, called to check if the process should be aborted.
*
* params.abort_callback = ->(user_data) {
* # ...
* }
*
* call-seq:
* abort_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_abort_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of abort callback.
*
* call-seq:
* abort_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->user_data = value;
return value;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_params_set_ ## param_name(self, value); \
continue; \
}
static VALUE
ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
VALUE value;
ruby_whisper_params *rwp;
ID id;
int i;
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return self;
}
rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
Data_Get_Struct(self, ruby_whisper_params, rwp);
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
id = param_names[i];
value = values[i];
if (value == Qundef) {
continue;
}
if (id == id_diarize) {
rwp->diarize = value;
continue;
} else {
SET_PARAM_IF_SAME(language)
SET_PARAM_IF_SAME(translate)
SET_PARAM_IF_SAME(no_context)
SET_PARAM_IF_SAME(single_segment)
SET_PARAM_IF_SAME(print_special)
SET_PARAM_IF_SAME(print_progress)
SET_PARAM_IF_SAME(print_realtime)
SET_PARAM_IF_SAME(print_timestamps)
SET_PARAM_IF_SAME(suppress_blank)
SET_PARAM_IF_SAME(suppress_nst)
SET_PARAM_IF_SAME(token_timestamps)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
SET_PARAM_IF_SAME(offset)
SET_PARAM_IF_SAME(duration)
SET_PARAM_IF_SAME(max_text_tokens)
SET_PARAM_IF_SAME(temperature)
SET_PARAM_IF_SAME(max_initial_ts)
SET_PARAM_IF_SAME(length_penalty)
SET_PARAM_IF_SAME(temperature_inc)
SET_PARAM_IF_SAME(entropy_thold)
SET_PARAM_IF_SAME(logprob_thold)
SET_PARAM_IF_SAME(no_speech_thold)
SET_PARAM_IF_SAME(new_segment_callback)
SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data)
SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data)
}
}
return self;
}
#undef SET_PARAM_IF_SAME
/*
* Hook called on new segment. Yields each Whisper::Segment.
*
* whisper.on_new_segment do |segment|
* # ...
* end
*
* call-seq:
* on_new_segment {|segment| ... }
*/
static VALUE
ruby_whisper_params_on_new_segment(VALUE self)
{
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
return Qnil;
}
/*
* Hook called on progress update. Yields each progress Integer between 0 and 100.
*
* whisper.on_progress do |progress|
* # ...
* end
*
* call-seq:
* on_progress {|progress| ... }
*/
static VALUE
ruby_whisper_params_on_progress(VALUE self)
{
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->progress_callback_container->callbacks, blk);
return Qnil;
}
/*
* Call block to determine whether abort or not. Return +true+ when you want to abort.
*
* params.abort_on do
* if some_condition
* true # abort
* else
* false # continue
* end
* end
*
* call-seq:
* abort_on { ... }
*/
static VALUE
ruby_whisper_params_abort_on(VALUE self)
{
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->abort_callback_container->callbacks, blk);
return Qnil;
}
void
init_ruby_whisper_params(VALUE *mWhisper)
{
cParams = rb_define_class_under(*mWhisper, "Params", rb_cObject);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "initialize", ruby_whisper_params_initialize, -1);
DEFINE_PARAM(language, 0)
DEFINE_PARAM(translate, 1)
DEFINE_PARAM(no_context, 2)
DEFINE_PARAM(single_segment, 3)
DEFINE_PARAM(print_special, 4)
DEFINE_PARAM(print_progress, 5)
DEFINE_PARAM(print_realtime, 6)
DEFINE_PARAM(print_timestamps, 7)
DEFINE_PARAM(suppress_blank, 8)
DEFINE_PARAM(suppress_nst, 9)
DEFINE_PARAM(token_timestamps, 10)
DEFINE_PARAM(split_on_word, 11)
DEFINE_PARAM(initial_prompt, 12)
DEFINE_PARAM(diarize, 13)
DEFINE_PARAM(offset, 14)
DEFINE_PARAM(duration, 15)
DEFINE_PARAM(max_text_tokens, 16)
DEFINE_PARAM(temperature, 17)
DEFINE_PARAM(max_initial_ts, 18)
DEFINE_PARAM(length_penalty, 19)
DEFINE_PARAM(temperature_inc, 20)
DEFINE_PARAM(entropy_thold, 21)
DEFINE_PARAM(logprob_thold, 22)
DEFINE_PARAM(no_speech_thold, 23)
DEFINE_PARAM(new_segment_callback, 24)
DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(abort_callback, 28)
DEFINE_PARAM(abort_callback_user_data, 29)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
}