@@ -4408,11 +4408,24 @@ struct whisper_vad_model {
44084408 std::map<std::string, struct ggml_tensor *> tensors;
44094409};
44104410
// Recurrent (LSTM) state carried by a VAD context between evaluations.
// Both vectors are resized to hparams.lstm_hidden_size when the model is loaded.
struct whisper_vad_state {
    std::vector<float> h; // LSTM hidden state
    std::vector<float> c; // LSTM cell state

    // Zero the state while keeping its current size.
    // clear() would shrink both vectors to size 0, so any later access
    // indexed by the hidden size would run past the end of the buffer.
    void reset() {
        h.assign(h.size(), 0.0f);
        c.assign(c.size(), 0.0f);
    }
};
4422+
44114423struct whisper_vad_context {
44124424 int64_t t_load_us = 0 ;
44134425 int64_t t_start_us = 0 ;
44144426
44154427 whisper_vad_model model;
4428+ whisper_vad_state state;
44164429
44174430 std::string path_model;
44184431};
@@ -4428,18 +4441,6 @@ struct whisper_vad_params whisper_vad_default_params(void) {
44284441 return result;
44294442}
44304443
// Fixed-size LSTM state: 128 hidden and 128 cell values.
struct whisper_vad_state {
    float h[128]; // LSTM hidden state
    float c[128]; // LSTM cell state

    // Set every element of both arrays to zero.
    void reset() {
        for (float & value : h) {
            value = 0.0f;
        }
        for (float & value : c) {
            value = 0.0f;
        }
    }
};
4442-
// Output of a single VAD evaluation.
struct whisper_vad_result {
    // Speech probability, in the range 0-1.
    float probability;
};
@@ -4526,10 +4527,14 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
45264527 return nullptr ;
45274528 }
45284529 }
4530+ whisper_vad_model model;
4531+ whisper_vad_state state;
4532+
45294533 whisper_vad_context * vctx = new whisper_vad_context;
4534+ vctx->model = model;
4535+ vctx->state = state;
45304536 vctx->path_model = path_model;
45314537
4532- whisper_vad_model model;
45334538 auto & hparams = model.hparams ;
45344539
45354540 // load model hyper params (hparams)
@@ -4572,6 +4577,9 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
45724577 WHISPER_LOG_INFO (" %s: final_conv_out = %d\n " , __func__, hparams.final_conv_out );
45734578 }
45744579
4580+ vctx->state .h .resize (hparams.lstm_hidden_size );
4581+ vctx->state .c .resize (hparams.lstm_hidden_size );
4582+
45754583 // 1 STFT, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
45764584 const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1 ;
45774585
@@ -4676,41 +4684,45 @@ whisper_vad_context * whisper_vad_init_from_file_with_params(
46764684 model.encoder_3_bias = create_tensor (VAD_TENSOR_ENC_3_BIAS,
46774685 ggml_new_tensor_1d (ctx, GGML_TYPE_F32, hparams.encoder_out_channels [3 ]));
46784686
4687+ // Hidden State dimension (input gate, forget gate, cell gate, output gate)
4688+ const int hstate_dim = hparams.lstm_hidden_size * 4 ;
4689+
46794690 // LSTM weights - input to hidden
46804691 model.lstm_weight_ih = create_tensor (
46814692 VAD_TENSOR_LSTM_WEIGHT_IH,
4682- ggml_new_tensor_2d (ctx, type, 128 , 512 )
4693+ ggml_new_tensor_2d (ctx, type, hparams. lstm_hidden_size , hstate_dim )
46834694 );
46844695
46854696 // LSTM weights - hidden to hidden
46864697 model.lstm_weight_hh = create_tensor (
46874698 VAD_TENSOR_LSTM_WEIGHT_HH,
4688- ggml_new_tensor_2d (ctx, type, 128 , 512 )
4699+ ggml_new_tensor_2d (ctx, type, hparams. lstm_hidden_size , hstate_dim )
46894700 );
46904701
46914702 // LSTM bias - input to hidden
46924703 model.lstm_bias_ih = create_tensor (
46934704 VAD_TENSOR_LSTM_BIAS_IH,
4694- ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 512 )
4705+ ggml_new_tensor_1d (ctx, GGML_TYPE_F32, hstate_dim )
46954706 );
46964707
46974708 // LSTM bias - hidden to hidden
46984709 model.lstm_bias_hh = create_tensor (
46994710 VAD_TENSOR_LSTM_BIAS_HH,
4700- ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 512 )
4711+ ggml_new_tensor_1d (ctx, GGML_TYPE_F32, hstate_dim )
47014712 );
47024713
47034714 // Final conv layer weight
47044715 model.final_conv_weight = create_tensor (
47054716 VAD_TENSOR_FINAL_CONV_WEIGHT,
4706- ggml_new_tensor_2d (ctx, GGML_TYPE_F32, 128 , 1 )
4717+ ggml_new_tensor_2d (ctx, GGML_TYPE_F32, hparams. final_conv_in , 1 )
47074718 );
47084719
47094720 // Final conv layer bias
47104721 model.final_conv_bias = create_tensor (
47114722 VAD_TENSOR_FINAL_CONV_BIAS,
47124723 ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 )
47134724 );
4725+
47144726 ggml_free (ctx);
47154727 }
47164728
0 commit comments