Skip to content

Commit 1ad095d

Browse files
committed
vad : add LSTM hidden state to VAD model
1 parent 19ea50a commit 1ad095d

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

src/whisper.cpp

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -4408,11 +4408,24 @@ struct whisper_vad_model {
44084408
std::map<std::string, struct ggml_tensor *> tensors;
44094409
};
44104410

4411+
// Recurrent (LSTM) state kept alongside the VAD model.
//
// Both vectors start empty; the model loader resizes them to the model's
// LSTM hidden size once the hyper-parameters have been read.
struct whisper_vad_state {
    // Hidden state for LSTM
    std::vector<float> h; // Hidden state, one value per hidden unit
    std::vector<float> c; // Cell state, same length as h

    // Reset state
    //
    // Zeroes every element in place and keeps the current length.
    // clear() must not be used here: it would shrink the vectors to size 0
    // and throw away the sizing done at load time, leaving later readers
    // of h/c with no storage to read from.
    void reset() {
        h.assign(h.size(), 0.0f);
        c.assign(c.size(), 0.0f);
    }
};
4422+
44114423
// Top-level VAD context: bundles the loaded model, its LSTM state and the
// path the model was loaded from, plus two timing fields.
struct whisper_vad_context {
44124424
int64_t t_load_us = 0; // presumably model load time in microseconds - TODO confirm against the loader
44134425
int64_t t_start_us = 0; // presumably a start timestamp in microseconds - TODO confirm
44144426

44154427
whisper_vad_model model; // VAD model (hyper-parameters and tensors)
4428+
whisper_vad_state state; // LSTM hidden/cell state; the loader resizes h and c to hparams.lstm_hidden_size
44164429

44174430
std::string path_model; // model path, assigned by the loader when the context is created
44184431
};
@@ -4428,18 +4441,6 @@ struct whisper_vad_params whisper_vad_default_params(void) {
44284441
return result;
44294442
}
44304443

4431-
struct whisper_vad_state {
4432-
// Hidden state for LSTM
4433-
float h[128]; // Hidden state dimension
4434-
float c[128]; // Cell state dimension
4435-
4436-
// Reset state
4437-
void reset() {
4438-
memset(h, 0, sizeof(h));
4439-
memset(c, 0, sizeof(c));
4440-
}
4441-
};
4442-
44434444
// Output of a VAD evaluation: a single speech-probability value.
struct whisper_vad_result {
44444445
float probability; // Speech probability (0-1)
44454446
};
@@ -4526,10 +4527,14 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
45264527
return nullptr;
45274528
}
45284529
}
4530+
whisper_vad_model model;
4531+
whisper_vad_state state;
4532+
45294533
whisper_vad_context * vctx = new whisper_vad_context;
4534+
vctx->model = model;
4535+
vctx->state = state;
45304536
vctx->path_model = path_model;
45314537

4532-
whisper_vad_model model;
45334538
auto & hparams = model.hparams;
45344539

45354540
// load model hyper params (hparams)
@@ -4572,6 +4577,9 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
45724577
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
45734578
}
45744579

4580+
vctx->state.h.resize(hparams.lstm_hidden_size);
4581+
vctx->state.c.resize(hparams.lstm_hidden_size);
4582+
45754583
// 1 STFT, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
45764584
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
45774585

@@ -4676,41 +4684,45 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
46764684
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
46774685
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
46784686

4687+
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
4688+
const int hstate_dim = hparams.lstm_hidden_size * 4;
4689+
46794690
// LSTM weights - input to hidden
46804691
model.lstm_weight_ih = create_tensor(
46814692
VAD_TENSOR_LSTM_WEIGHT_IH,
4682-
ggml_new_tensor_2d(ctx, type, 128, 512)
4693+
ggml_new_tensor_2d(ctx, type, hparams.lstm_hidden_size, hstate_dim)
46834694
);
46844695

46854696
// LSTM weights - hidden to hidden
46864697
model.lstm_weight_hh = create_tensor(
46874698
VAD_TENSOR_LSTM_WEIGHT_HH,
4688-
ggml_new_tensor_2d(ctx, type, 128, 512)
4699+
ggml_new_tensor_2d(ctx, type, hparams.lstm_hidden_size, hstate_dim)
46894700
);
46904701

46914702
// LSTM bias - input to hidden
46924703
model.lstm_bias_ih = create_tensor(
46934704
VAD_TENSOR_LSTM_BIAS_IH,
4694-
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 512)
4705+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
46954706
);
46964707

46974708
// LSTM bias - hidden to hidden
46984709
model.lstm_bias_hh = create_tensor(
46994710
VAD_TENSOR_LSTM_BIAS_HH,
4700-
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 512)
4711+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
47014712
);
47024713

47034714
// Final conv layer weight
47044715
model.final_conv_weight = create_tensor(
47054716
VAD_TENSOR_FINAL_CONV_WEIGHT,
4706-
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 1)
4717+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.final_conv_in, 1)
47074718
);
47084719

47094720
// Final conv layer bias
47104721
model.final_conv_bias = create_tensor(
47114722
VAD_TENSOR_FINAL_CONV_BIAS,
47124723
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
47134724
);
4725+
47144726
ggml_free(ctx);
47154727
}
47164728

0 commit comments

Comments
 (0)