Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions bindings/ruby/ext/ruby_whisper_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,61 @@ ruby_whisper_normalize_model_path(VALUE model_path)
* call-seq:
* new("base.en") -> Whisper::Context
* new("path/to/model.bin") -> Whisper::Context
* new("path/to/model.bin", use_gpu: true, flash_attn: true) -> Whisper::Context
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
*
* Initialize a new Whisper context with optional parameters:
* use_gpu: Enable GPU acceleration (default: true)
* flash_attn: Enable flash attention (default: true)
* gpu_device: GPU device to use (default: 0)
* dtw_token_timestamps: Enable DTW token-level timestamps (default: false)
* dtw_aheads_preset: DTW attention heads preset (default: WHISPER_AHEADS_NONE)
*/
static VALUE
ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper *rw;
VALUE whisper_model_file_path;
VALUE whisper_model_file_path, options;

// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
rb_scan_args(argc, argv, "01:", &whisper_model_file_path, &options);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);

whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());

// Build context params from options
struct whisper_context_params cparams = whisper_context_default_params();

if (!NIL_P(options)) {
VALUE use_gpu = rb_hash_aref(options, ID2SYM(rb_intern("use_gpu")));
if (!NIL_P(use_gpu)) {
cparams.use_gpu = RTEST(use_gpu);
}

VALUE flash_attn = rb_hash_aref(options, ID2SYM(rb_intern("flash_attn")));
if (!NIL_P(flash_attn)) {
cparams.flash_attn = RTEST(flash_attn);
}

VALUE gpu_device = rb_hash_aref(options, ID2SYM(rb_intern("gpu_device")));
if (!NIL_P(gpu_device)) {
cparams.gpu_device = NUM2INT(gpu_device);
}

VALUE dtw_token_timestamps = rb_hash_aref(options, ID2SYM(rb_intern("dtw_token_timestamps")));
if (!NIL_P(dtw_token_timestamps)) {
cparams.dtw_token_timestamps = RTEST(dtw_token_timestamps);
}

VALUE dtw_aheads_preset = rb_hash_aref(options, ID2SYM(rb_intern("dtw_aheads_preset")));
if (!NIL_P(dtw_aheads_preset)) {
cparams.dtw_aheads_preset = NUM2INT(dtw_aheads_preset);
}
}

rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), cparams);
if (rw->context == NULL) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
Expand Down
116 changes: 115 additions & 1 deletion bindings/ruby/ext/ruby_whisper_params.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);

#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 42

extern VALUE cParams;
extern VALUE cVADParams;
Expand Down Expand Up @@ -75,6 +75,11 @@ static ID id_abort_callback_user_data;
static ID id_vad;
static ID id_vad_model_path;
static ID id_vad_params;
static ID id_suppress_regex;
static ID id_grammar_penalty;
static ID id_tdrz_enable;
static ID id_audio_ctx;
static ID id_debug_mode;

static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
Expand Down Expand Up @@ -1141,6 +1146,105 @@ ruby_whisper_params_get_vad_params(VALUE self)
return rwp->vad_params;
}

/*
 * call-seq:
 *   suppress_regex = regex -> regex
 *
 * Sets the suppress_regex field of the wrapped whisper params.
 * Passing nil clears the field (stores NULL); any other value is converted
 * with StringValueCStr. The argument is returned unchanged.
 */
static VALUE
ruby_whisper_params_set_suppress_regex(VALUE self, VALUE value)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  // NOTE(review): the stored pointer is the one returned by StringValueCStr,
  // i.e. it borrows the Ruby string's buffer. Nothing in this function keeps
  // that string alive - confirm its lifetime is guaranteed elsewhere.
  holder->params.suppress_regex = NIL_P(value) ? NULL : StringValueCStr(value);

  return value;
}

/*
 * Returns the current suppress_regex as a new Ruby String, or nil when the
 * underlying field is NULL.
 */
static VALUE
ruby_whisper_params_get_suppress_regex(VALUE self)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  if (holder->params.suppress_regex == NULL) {
    return Qnil;
  }
  return rb_str_new2(holder->params.suppress_regex);
}

/*
 * call-seq:
 *   grammar_penalty = penalty -> penalty
 *
 * Converts the argument with NUM2DBL and stores it in the grammar_penalty
 * field of the wrapped whisper params. The argument is returned unchanged.
 */
static VALUE
ruby_whisper_params_set_grammar_penalty(VALUE self, VALUE value)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  double penalty = NUM2DBL(value);
  holder->params.grammar_penalty = penalty;

  return value;
}

/*
 * Returns the grammar_penalty field of the wrapped whisper params as a
 * Ruby Float (via DBL2NUM).
 */
static VALUE
ruby_whisper_params_get_grammar_penalty(VALUE self)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  double penalty = holder->params.grammar_penalty;
  return DBL2NUM(penalty);
}

/*
 * call-seq:
 *   tdrz_enable = enable -> enable
 *
 * Setter for the tdrz_enable parameter. The whole body, including the
 * return statement, is produced by the BOOL_PARAMS_SETTER macro, whose
 * definition is outside this section of the file.
 */
static VALUE
ruby_whisper_params_set_tdrz_enable(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, tdrz_enable, value)
}

/*
 * Getter for the tdrz_enable parameter. The whole body, including the
 * return statement, is produced by the BOOL_PARAMS_GETTER macro, whose
 * definition is outside this section of the file.
 */
static VALUE
ruby_whisper_params_get_tdrz_enable(VALUE self)
{
BOOL_PARAMS_GETTER(self, tdrz_enable)
}

/*
 * call-seq:
 *   audio_ctx = context_size -> context_size
 *
 * Converts the argument with NUM2INT and stores it in the audio_ctx field
 * of the wrapped whisper params. The argument is returned unchanged.
 */
static VALUE
ruby_whisper_params_set_audio_ctx(VALUE self, VALUE value)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  int requested = NUM2INT(value);
  holder->params.audio_ctx = requested;

  return value;
}

/*
 * Returns the audio_ctx field of the wrapped whisper params as a Ruby
 * Integer (via INT2NUM).
 */
static VALUE
ruby_whisper_params_get_audio_ctx(VALUE self)
{
  ruby_whisper_params *holder;

  TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, holder);

  int current = holder->params.audio_ctx;
  return INT2NUM(current);
}

/*
 * call-seq:
 *   debug_mode = enable -> enable
 *
 * Setter for the debug_mode parameter. The whole body, including the
 * return statement, is produced by the BOOL_PARAMS_SETTER macro, whose
 * definition is outside this section of the file.
 */
static VALUE
ruby_whisper_params_set_debug_mode(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, debug_mode, value)
}

/*
 * Getter for the debug_mode parameter. The whole body, including the
 * return statement, is produced by the BOOL_PARAMS_GETTER macro, whose
 * definition is outside this section of the file.
 */
static VALUE
ruby_whisper_params_get_debug_mode(VALUE self)
{
BOOL_PARAMS_GETTER(self, debug_mode)
}

#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_params_set_ ## param_name(self, value); \
Expand Down Expand Up @@ -1211,6 +1315,11 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(vad)
SET_PARAM_IF_SAME(vad_model_path)
SET_PARAM_IF_SAME(vad_params)
SET_PARAM_IF_SAME(suppress_regex)
SET_PARAM_IF_SAME(grammar_penalty)
SET_PARAM_IF_SAME(tdrz_enable)
SET_PARAM_IF_SAME(audio_ctx)
SET_PARAM_IF_SAME(debug_mode)
}
}

Expand Down Expand Up @@ -1348,6 +1457,11 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(vad, 34)
DEFINE_PARAM(vad_model_path, 35)
DEFINE_PARAM(vad_params, 36)
DEFINE_PARAM(suppress_regex, 37)
DEFINE_PARAM(grammar_penalty, 38)
DEFINE_PARAM(tdrz_enable, 39)
DEFINE_PARAM(audio_ctx, 40)
DEFINE_PARAM(debug_mode, 41)

rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
Expand Down
4 changes: 2 additions & 2 deletions bindings/ruby/lib/whisper/context.rb
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
module Whisper
class Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
each_segment.with_index.reduce(String.new) {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end

def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
each_segment.with_index.reduce(String.new("WEBVTT\n\n")) {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
Expand Down
59 changes: 59 additions & 0 deletions bindings/ruby/test/test_context_params.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
require_relative "helper"

# Smoke tests for the keyword options accepted by Whisper::Context.new.
# Each test only checks that construction returns a Whisper::Context.
class TestContextParams < TestBase
  def test_context_new_with_default_params
    assert_context_created
  end

  def test_context_new_with_use_gpu
    [true, false].each do |enabled|
      assert_context_created(use_gpu: enabled)
    end
  end

  def test_context_new_with_flash_attn
    [true, false].each do |enabled|
      assert_context_created(flash_attn: enabled)
    end
  end

  def test_context_new_with_gpu_device
    [0, 1].each do |device|
      assert_context_created(gpu_device: device)
    end
  end

  def test_context_new_with_dtw_token_timestamps
    [true, false].each do |enabled|
      assert_context_created(dtw_token_timestamps: enabled)
    end
  end

  def test_context_new_with_dtw_aheads_preset
    [0, 1].each do |preset|
      assert_context_created(dtw_aheads_preset: preset)
    end
  end

  def test_context_new_with_combined_params
    assert_context_created(
      use_gpu: true,
      flash_attn: true,
      gpu_device: 0,
      dtw_token_timestamps: false,
      dtw_aheads_preset: 0
    )
  end

  private

  # Builds a context for the "base.en" model with the given keyword options
  # and asserts that the result is a Whisper::Context.
  def assert_context_created(**options)
    context = Whisper::Context.new("base.en", **options)
    assert_instance_of Whisper::Context, context
  end
end
54 changes: 49 additions & 5 deletions bindings/ruby/test/test_params.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class TestParams < TestBase
:vad,
:vad_model_path,
:vad_params,
:suppress_regex,
:grammar_penalty,
:tdrz_enable,
:audio_ctx,
:debug_mode,
]

def setup
Expand Down Expand Up @@ -245,13 +250,50 @@ def test_vad_model_path_with_URI
end

def test_vad_params
assert_kind_of Whisper::VAD::Params, @params.vad_params
default_params = @params.vad_params
assert_same default_params, @params.vad_params
assert_equal 0.5, default_params.threshold
default_params = Whisper::VAD::Params.new
# vad_params returns a new wrapper each time, so use assert_equal instead of assert_same
retrieved_params = @params.vad_params
assert_equal default_params.threshold, retrieved_params.threshold
assert_equal 0.5, retrieved_params.threshold
new_params = Whisper::VAD::Params.new
@params.vad_params = new_params
assert_same new_params, @params.vad_params
retrieved_params = @params.vad_params
assert_equal new_params.threshold, retrieved_params.threshold
end

# suppress_regex round-trips a pattern string and can be cleared with nil.
def test_suppress_regex
  pattern = "[\\*\\[\\]]"
  @params.suppress_regex = pattern
  # Minitest's signature is assert_equal(expected, actual); keeping that order
  # makes failure messages report the right side as "Expected".
  assert_equal pattern, @params.suppress_regex
  @params.suppress_regex = nil
  assert_nil @params.suppress_regex
end

# grammar_penalty round-trips float values (compared with a delta because the
# value passes through a C floating-point field).
def test_grammar_penalty
  @params.grammar_penalty = 50.0
  # assert_in_delta(expected, actual): expected value goes first.
  assert_in_delta 50.0, @params.grammar_penalty
  @params.grammar_penalty = 0.0
  assert_in_delta 0.0, @params.grammar_penalty
end

# tdrz_enable can be switched on and off and reads back the value last set.
def test_tdrz_enable
  @params.tdrz_enable = true
  assert @params.tdrz_enable
  @params.tdrz_enable = false
  assert !@params.tdrz_enable
end

# audio_ctx round-trips integer values, including 0.
def test_audio_ctx
  @params.audio_ctx = 1024
  # assert_equal(expected, actual): expected value goes first.
  assert_equal 1024, @params.audio_ctx
  @params.audio_ctx = 0
  assert_equal 0, @params.audio_ctx
end

# debug_mode can be switched on and off and reads back the value last set.
def test_debug_mode
  @params.debug_mode = true
  assert @params.debug_mode
  @params.debug_mode = false
  assert !@params.debug_mode
end

def test_new_with_kw_args
Expand Down Expand Up @@ -284,6 +326,8 @@ def test_new_with_kw_args_default_values(param)
"es"
in [:initial_prompt, *]
"Initial prompt"
in [:suppress_regex, *]
"[\\*\\[\\]]"
in [/_callback\Z/, *]
proc {}
in [/_user_data\Z/, *]
Expand Down
1 change: 0 additions & 1 deletion bindings/ruby/test/test_segment.rb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_on_new_segment_twice
end

def test_transcription_after_segment_retrieved
params = Whisper::Params.new
segment = whisper.each_segment.first
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)

Expand Down