Skip to content

Commit fab4ec8

Browse files
committed
vad : fix tensor dimensions for VAD operations
1 parent 3fd945c commit fab4ec8

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

src/whisper.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4477,10 +4477,11 @@ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams &
44774477

44784478
static ggml_tensor * whisper_vad_build_stft_layer(ggml_context* ctx0,
44794479
const whisper_vad_model & model, ggml_tensor * cur) {
4480-
// We need the stft tensor to be in {258, 1, 256},
4481-
// that is a kernel size of 258, 1 channel, and 256 frequency bins (output)
4482-
struct ggml_tensor * stft_reshaped = ggml_reshape_3d(ctx0, model.stft_forward_basis, 258, 1, 256);
4483-
cur = ggml_conv_1d(ctx0, stft_reshaped, cur, 1, 1, 1);
4480+
ggml_tensor* padded = ggml_pad(ctx0, cur, 64, 0, 0, 0);
4481+
// We need the stft tensor to be in {256, 1, 258},
4482+
// 256 frequency bins (output), 1 channel (input), and 258 kernel size.
4483+
struct ggml_tensor * stft_reshaped = ggml_reshape_3d(ctx0, model.stft_forward_basis, 256, 1, 258);
4484+
cur = ggml_conv_1d(ctx0, stft_reshaped, padded, 128, 0, 1);
44844485
ggml_set_name(cur, "stft");
44854486
ggml_set_output(cur);
44864487
return cur;
@@ -4489,19 +4490,20 @@ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context* ctx0,
44894490
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context* ctx0,
44904491
const whisper_vad_model & model, ggml_tensor * cur) {
44914492
WHISPER_LOG_INFO("%s: building encoder layer\n", __func__);
4492-
// Reshape from the STFT output which is [258, 1, 1, 1] where the first
4493+
// Reshape from the STFT output which is [4, 258, 1, 1] where the second
44934494
// dimension are complex number pairs. I think we can ignore the imaginary
44944495
// part and just use the real part here.
4495-
struct ggml_tensor * real_part = ggml_view_1d(ctx0, cur, 129, 0);
4496-
struct ggml_tensor * reshaped = ggml_reshape_3d(ctx0, real_part, 1, 129, 1);
4496+
struct ggml_tensor * real_part = ggml_view_2d(ctx0, cur, 4, 129,
4497+
cur->nb[0], // stride for moving between frequency bins
4498+
0); // offset = 0 to start from the beginning
44974499

44984500
// First Conv1D: expands to 128 channels.
4499-
cur = ggml_conv_1d(ctx0, model.encoder_0_weight, reshaped, 1, 1, 1);
4501+
cur = ggml_conv_1d(ctx0, model.encoder_0_weight, real_part, 2, 1, 1);
45004502
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
45014503
cur = ggml_relu(ctx0, cur);
45024504

4503-
// First Conv1D: reduces to 64 channels.
4504-
cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 1, 1, 1);
4505+
// Second Conv1D: reduces to 64 channels.
4506+
cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
45054507
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
45064508
cur = ggml_relu(ctx0, cur);
45074509

@@ -4523,8 +4525,6 @@ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context* ctx0,
45234525
WHISPER_LOG_INFO("%s: building LSTM layer\n", __func__);
45244526

45254527
const whisper_vad_model & model = vctx.model;
4526-
const int seq_length = cur->ne[0];
4527-
const int input_dim = cur->ne[1];
45284528
const int hdim = model.hparams.lstm_hidden_size;
45294529
const int hdim_bytes = hdim * sizeof(float);
45304530

@@ -4597,23 +4597,27 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
45974597

45984598
WHISPER_LOG_INFO("%s: n_window = %d\n", __func__, n_window);
45994599
// We process one frame/segment at a time of size n_window.
4600-
struct ggml_tensor * frame = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_window);
4600+
struct ggml_tensor * frame = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_window, 1, 1);
46014601
ggml_set_name(frame, "frame");
46024602
ggml_set_input(frame);
46034603

46044604
struct ggml_tensor * cur = nullptr;
46054605
{
46064606
cur = whisper_vad_build_stft_layer(ctx0, model, frame);
4607+
WHISPER_LOG_INFO("%s: stft output shape = [%d, %d, %d]\n", __func__, cur->ne[0], cur->ne[1], cur->ne[2]);
46074608

46084609
cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4610+
WHISPER_LOG_INFO("%s: endoder output shape = [%d, %d, %d]\n", __func__, cur->ne[0], cur->ne[1], cur->ne[2]);
46094611

46104612
cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur);
4613+
WHISPER_LOG_INFO("%s: lstm output shape = [%d, %d, %d]\n", __func__, cur->ne[0], cur->ne[1], cur->ne[2]);
46114614

46124615
cur = ggml_relu(ctx0, cur);
46134616

46144617
// Final output layer
46154618
cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
46164619
cur = ggml_add(ctx0, cur, model.final_conv_bias);
4620+
WHISPER_LOG_INFO("%s: final decoder output shape = [%d, %d, %d]\n", __func__, cur->ne[0], cur->ne[1], cur->ne[2]);
46174621

46184622
// Apply sigmoid to get probability between 0 and 1
46194623
cur = ggml_sigmoid(ctx0, cur);

0 commit comments

Comments
 (0)