Skip to content

Commit 5365dab

Browse files
committed
vad : remove sigmoid activation from VAD output
1 parent 67efe42 commit 5365dab

File tree

1 file changed: 6 additions and 10 deletions

src/whisper.cpp

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4608,10 +4608,7 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
46084608
// Final output layer - linear transformation from LSTM output
46094609
cur = ggml_mul_mat(ctx0, model.final_conv_weight, cur);
46104610
cur = ggml_add(ctx0, cur, model.final_conv_bias);
4611-
4612-
// Apply sigmoid to get probability between 0 and 1
4613-
cur = ggml_sigmoid(ctx0, cur);
4614-
ggml_set_name(cur, "prob");
4611+
ggml_set_name(cur, "logits");
46154612
ggml_set_output(cur);
46164613
}
46174614

@@ -4773,7 +4770,6 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
47734770
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
47744771
}
47754772

4776-
47774773
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
47784774
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
47794775

@@ -5221,7 +5217,7 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52215217
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
52225218
struct ggml_tensor * c_out = ggml_graph_get_tensor(gf, "c_out");
52235219
struct ggml_tensor * h_out = ggml_graph_get_tensor(gf, "h_out");
5224-
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
5220+
struct ggml_tensor * logits_tensor = ggml_graph_get_tensor(gf, "logits");
52255221

52265222
struct ggml_tensor * c_in = ggml_graph_get_tensor(gf, "c_in");
52275223
struct ggml_tensor * h_in = ggml_graph_get_tensor(gf, "h_in");
@@ -5232,7 +5228,7 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52325228
std::vector<float> c_state(hidden_dim, 0.0f);
52335229

52345230
int n_frames = n_samples / vctx->window_size_samples;
5235-
std::vector<float> probs(n_frames, 0.0f);
5231+
std::vector<float> logits(n_frames, 0.0f);
52365232

52375233
WHISPER_LOG_INFO("%s: frame tensor size: %ld\n", __func__, frame->ne[0]);
52385234

@@ -5267,13 +5263,13 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52675263

52685264
ggml_backend_tensor_get(h_out, h_state.data(), 0, hidden_dim * sizeof(float));
52695265
ggml_backend_tensor_get(c_out, c_state.data(), 0, hidden_dim * sizeof(float));
5270-
ggml_backend_tensor_get(prob, &probs[i/vctx->window_size_samples], 0, sizeof(float));
5266+
ggml_backend_tensor_get(logits_tensor, &logits[i/vctx->window_size_samples], 0, sizeof(float));
52715267

52725268
vctx->current_sample += vctx->window_size_samples;
52735269
}
52745270
WHISPER_LOG_INFO("%s: finished processing %d samples\n", __func__, n_samples);
5275-
for (int i = 0; i < probs.size(); i++) {
5276-
//WHISPER_LOG_INFO("%s: prob[%d]: %f\n", __func__, i, probs[i]);
5271+
for (int i = 0; i < logits.size(); i++) {
5272+
WHISPER_LOG_INFO("%s: logits[%d]: %f\n", __func__, i, logits[i]);
52775273
}
52785274

52795275
segments.n_segments = n_frames;

0 commit comments

Comments (0)