Skip to content

Commit cf46fc9

Browse files
committed
vad : add context to frames
1 parent 34c5ec6 commit cf46fc9

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

src/whisper.cpp

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4429,6 +4429,7 @@ struct whisper_vad_context {
44294429
int64_t t_start_us = 0;
44304430

44314431
int n_window;
4432+
int n_context;
44324433
std::string path_model;
44334434

44344435
whisper_vad_model model;
@@ -4573,7 +4574,7 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
45734574
whisper_vad_state & vstate) {
45744575
const auto & model = vctx.model;
45754576
const auto & hparams = model.hparams;
4576-
const int n_window = vctx.n_window;
4577+
const int frame_size = vctx.n_window + vctx.n_context;
45774578

45784579
WHISPER_LOG_INFO("%s: Building VAD graph\n", __func__);
45794580
struct ggml_init_params params = {
@@ -4586,9 +4587,8 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
45864587

45874588
ggml_cgraph * gf = ggml_new_graph(ctx0);
45884589

4589-
WHISPER_LOG_INFO("%s: n_window = %d\n", __func__, n_window);
4590-
// We process one frame/segment at a time of size n_window.
4591-
struct ggml_tensor * frame = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_window, 1, 1);
4590+
WHISPER_LOG_INFO("%s: n_window = %d\n", __func__, frame_size);
4591+
struct ggml_tensor * frame = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, frame_size, 1, 1);
45924592
ggml_set_name(frame, "frame");
45934593
ggml_set_input(frame);
45944594

@@ -4725,7 +4725,9 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
47254725

47264726
whisper_vad_context * vctx = new whisper_vad_context;
47274727
vctx->path_model = path_model;
4728+
// TODO(danbev) Read these from the model since they are tied to the model.
47284729
vctx->n_window = 512;
4730+
vctx->n_context = 64;
47294731

47304732
auto & model = vctx->model;
47314733
auto & hparams = model.hparams;
@@ -5196,6 +5198,8 @@ struct whisper_vad_segments whisper_vad_detect_speech(
51965198
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
51975199
auto & sched = vctx->state->sched.sched;
51985200
const int hidden_dim = vctx->model.hparams.lstm_hidden_size;
5201+
const int n_context = vctx->n_context;
5202+
const int frame_size = vctx->n_window + n_context;
51995203

52005204
struct whisper_vad_segments segments {
52015205
/* n_segments = */ 0,
@@ -5212,8 +5216,9 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52125216
return segments;
52135217
}
52145218

5215-
WHISPER_LOG_INFO("%s: n_window: %u\n", __func__, vctx->n_window);
5216-
std::vector<float> window_with_context(vctx->n_window);
5219+
WHISPER_LOG_INFO("%s: frame_size: %u\n", __func__, frame_size);
5220+
std::vector<float> previous_context(n_context, 0.0f);
5221+
std::vector<float> window_with_context(frame_size + previous_context.size());
52175222

52185223
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
52195224
struct ggml_tensor * c_out = ggml_graph_get_tensor(gf, "c_out");
@@ -5231,18 +5236,25 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52315236
std::vector<float> h_state(hidden_dim, 0.0f);
52325237
std::vector<float> c_state(hidden_dim, 0.0f);
52335238

5234-
int n_frames = n_samples / vctx->n_window;
5239+
int n_frames = n_samples / window_with_context.size();
52355240
std::vector<float> probs(n_frames, 0.0f);
52365241

52375242
WHISPER_LOG_INFO("%s: frame tensor size: %ld\n", __func__, frame->ne[0]);
52385243

5239-
for (int i = 0; i < n_samples; i += vctx->n_window) {
5244+
for (int i = 0; i < n_samples; i += frame_size) {
52405245
// Skip if we don't have enough samples for a full window
5241-
if (i + vctx->n_window > n_samples) {
5246+
if (i + frame_size > n_samples) {
52425247
break;
52435248
}
52445249

5245-
ggml_backend_tensor_set(frame, pcmf32 + i, 0, ggml_nelements(frame) * sizeof(float));
5250+
// Copy the previous context to the beginning of window_with_context.
5251+
std::copy(previous_context.begin(), previous_context.end(), window_with_context.begin());
5252+
5253+
// Copy current frame samples to after the context.
5254+
std::copy(pcmf32 + i, pcmf32 + i + frame_size, window_with_context.begin() + previous_context.size());
5255+
5256+
// Set the frame tensor data with the context + the samples.
5257+
ggml_backend_tensor_set(frame, window_with_context.data(), 0, ggml_nelements(frame) * sizeof(float));
52465258

52475259
ggml_backend_tensor_set(h_in, h_state.data(), 0, hidden_dim * sizeof(float));
52485260
ggml_backend_tensor_set(c_in, c_state.data(), 0, hidden_dim * sizeof(float));
@@ -5252,7 +5264,7 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52525264
break;
52535265
}
52545266

5255-
// Print out some intermediate results
5267+
// Print out some intermediate results for debugging
52565268
WHISPER_LOG_INFO("%s:###### Intermediate results #####\n", __func__);
52575269
{
52585270
struct ggml_tensor * tensor = ggml_graph_get_tensor(gf, "stft");
@@ -5262,6 +5274,7 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52625274
WHISPER_LOG_INFO("%s: sftf: [%d]: %f\n", __func__, i, read_b[i]);
52635275
}
52645276
}
5277+
// Print out some intermediate results for debugging
52655278
{
52665279
struct ggml_tensor * tensor = ggml_graph_get_tensor(gf, "final_conv");
52675280
std::vector<float> read_b(ggml_nbytes(tensor));
@@ -5271,16 +5284,18 @@ struct whisper_vad_segments whisper_vad_detect_speech(
52715284
}
52725285
}
52735286

5287+
// Update the LSTM states
52745288
ggml_backend_tensor_get(h_out, h_state.data(), 0, hidden_dim * sizeof(float));
52755289
ggml_backend_tensor_get(c_out, c_state.data(), 0, hidden_dim * sizeof(float));
52765290

5277-
WHISPER_LOG_INFO("%s: h_state first 3 values: %f, %f, %f\n",
5278-
__func__, h_state[0], h_state[1], h_state[2]);
5279-
WHISPER_LOG_INFO("%s: c_state first 3 values: %f, %f, %f\n",
5280-
__func__, c_state[0], c_state[1], c_state[2]);
5291+
WHISPER_LOG_INFO("%s: h_state first 3 values: %f, %f, %f\n", __func__, h_state[0], h_state[1], h_state[2]);
5292+
WHISPER_LOG_INFO("%s: c_state first 3 values: %f, %f, %f\n", __func__, c_state[0], c_state[1], c_state[2]);
52815293

5282-
ggml_backend_tensor_get(prob, &probs[i/vctx->n_window], 0, sizeof(float));
5294+
// Get the probabilities.
5295+
ggml_backend_tensor_get(prob, &probs[i/frame_size], 0, sizeof(float));
52835296

5297+
// Copy the last n_context to add to the next frame.
5298+
std::copy(window_with_context.end() - n_context, window_with_context.end(), previous_context.begin());
52845299
}
52855300
WHISPER_LOG_INFO("%s: finished processing %d samples\n", __func__, n_samples);
52865301
for (size_t i = 0; i < probs.size(); i++) {

0 commit comments

Comments
 (0)