Skip to content

Commit 0233dbd

Browse files
authored
Add FeatureBuffer support to Cache-Aware streaming pipeline (#15188)
* Add FeatureBuffer support to Cache-Aware RNNT streaming pipeline Signed-off-by: arushid <arushid@nvidia.com> * Apply isort and black reformatting Signed-off-by: arushidNV <arushidNV@users.noreply.github.com> * Adding feature support for Cache Aware CTC pipeline Signed-off-by: arushid <arushid@nvidia.com> * Apply isort and black reformatting Signed-off-by: arushidNV <arushidNV@users.noreply.github.com> * Changed [Frame| Feature] to [Request] Signed-off-by: arushid <arushid@nvidia.com> * Apply isort and black reformatting Signed-off-by: arushidNV <arushidNV@users.noreply.github.com> * Adding feature buffer support in request generator Signed-off-by: arushid <arushid@nvidia.com> * Fixing issues with feature buffer support in request generator Signed-off-by: arushid <arushid@nvidia.com> * Updated comment in config Signed-off-by: arushid <arushid@nvidia.com> * Resolving code reviews Signed-off-by: arushid <arushid@nvidia.com> --------- Signed-off-by: arushid <arushid@nvidia.com> Signed-off-by: arushidNV <arushidNV@users.noreply.github.com> Co-authored-by: arushidNV <arushidNV@users.noreply.github.com>
1 parent 77199b5 commit 0233dbd

File tree

4 files changed

+145
-53
lines changed

4 files changed

+145
-53
lines changed

examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ streaming:
7575
use_cache: true # Whether to use cache for streaming
7676
use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer
7777
chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames.
78-
request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming
78+
request_type: frame # Type of request: frame or feature_buffer
7979
num_slots: 1024 # Number of slots in the context manager: must be >= batch_size
8080

8181

examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# ASR Configuration
33
# ================================
44
asr:
5-
model_name: stt_en_fastconformer_hybrid_large_streaming_multi # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path
5+
model_name: nvidia/nemotron-speech-streaming-en-0.6b # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path
66
device: cuda # Device for inference: 'cuda' or 'cpu'
77
device_id: 0 # GPU device ID
88
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
@@ -85,14 +85,14 @@ endpointing:
8585
# ========================
8686
streaming:
8787
sample_rate: 16000 # Audio sample rate in Hz
88-
batch_size: 256 # Number of audio frames per batch
88+
batch_size: 64 # Number of audio frames per batch
8989
word_boundary_tolerance: 4 # Tolerance for word boundaries
9090
att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0]
9191
use_cache: true # Whether to use cache for streaming
9292
use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer
9393
chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames.
94-
request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming
95-
num_slots: 1024 # Number of slots in the context manager: must be >= batch_size
94+
request_type: frame # Type of request: frame or feature_buffer
95+
num_slots: 256 # Number of slots in the context manager: must be >= batch_size
9696

9797

9898
# ========================

nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@
2929
from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import CTCGreedyDecoder
3030
from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_ctc_endpointing import CTCGreedyEndpointing
3131
from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer
32-
from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame
32+
from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request
3333
from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
3434
from nemo.collections.asr.inference.streaming.state.cache_aware_ctc_state import CacheAwareCTCStreamingState
3535
from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames
3636
from nemo.collections.asr.inference.utils.enums import RequestType
3737
from nemo.collections.asr.inference.utils.pipeline_utils import (
3838
check_existance_of_required_attributes,
39+
drop_trailing_features,
3940
get_confidence_utils,
4041
normalize_log_probs,
4142
)
@@ -94,9 +95,6 @@ def init_parameters(self, cfg: DictConfig) -> None:
9495
f"Number of slots in the context manager must be >= batch_size: {self.num_slots} < {self.batch_size}"
9596
)
9697
self.request_type = RequestType.from_str(cfg.streaming.request_type)
97-
if self.request_type is not RequestType.FRAME:
98-
raise ValueError(f"Request type {self.request_type} is not supported for cache-aware streaming.")
99-
10098
self.word_boundary_tolerance = cfg.streaming.word_boundary_tolerance
10199
self.stop_history_eou_in_milliseconds = cfg.endpointing.stop_history_eou
102100
self.residue_tokens_at_end = cfg.endpointing.residue_tokens_at_end
@@ -142,6 +140,9 @@ def init_parameters(self, cfg: DictConfig) -> None:
142140
self.drop_left_context = left_context_size
143141
self.valid_out_len = self.tokens_per_frame
144142

143+
# Expected feature buffer length for trimming (safeguard for feature buffer inputs)
144+
self.expected_feature_buffer_len = int(self.buffer_size_in_secs / self.window_stride)
145+
145146
def init_greedy_ctc_decoder(self) -> None:
146147
"""Initialize the CTC decoder."""
147148
check_existance_of_required_attributes(self, ['vocabulary', 'conf_func'])
@@ -210,24 +211,28 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N
210211
(tuple[Tensor, Tensor]) Processed feature buffers and their lengths.
211212
"""
212213
feature_buffers = [f_buffer.unsqueeze_(0) for f_buffer in buffers]
214+
# Trim to expected feature buffer length (safeguard for external feature buffer inputs)
215+
feature_buffers = [
216+
drop_trailing_features(f_buffer, self.expected_feature_buffer_len) for f_buffer in feature_buffers
217+
]
213218
feature_buffer_lens = torch.tensor([f_buffer.shape[2] for f_buffer in feature_buffers], device=self.device)
214219
if right_paddings is not None:
215220
right_paddings = torch.tensor(right_paddings, device=feature_buffer_lens.device)
216221
feature_buffer_lens = feature_buffer_lens - right_paddings
217222
feature_buffers = torch.cat(feature_buffers).to(self.device)
218223
return feature_buffers, feature_buffer_lens
219224

220-
def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame, log_probs: Tensor):
225+
def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, request: Request, log_probs: Tensor):
221226
"""
222227
Run the greedy CTC decoder on the log_probs and update the state
223228
Args:
224229
state: (CacheAwareCTCStreamingState) The state of the stream
225-
frame: (Frame) The current frame
226-
log_probs: (Tensor) The log probabilities of the current frame
230+
request: (Request) The current request (frame or feature buffer)
231+
log_probs: (Tensor) The log probabilities of the current request
227232
Returns:
228233
(bool) Whether EOU is detected.
229234
"""
230-
eou_detected = frame.is_last
235+
eou_detected = request.is_last
231236
last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id
232237
cur_output = self.greedy_ctc_decoder(log_probs, compute_confidence=True, previous=last_token)
233238
state.update_label_buffer(cur_output["labels"])
@@ -244,25 +249,29 @@ def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame, l
244249
return eou_detected
245250

246251
def decode_log_probs(
247-
self, frames: list[Frame], log_probs: Tensor, tail_log_probs: Tensor | None, ready_state_ids: set
252+
self,
253+
requests: list[Request],
254+
log_probs: Tensor,
255+
tail_log_probs: Tensor | None,
256+
ready_state_ids: set,
248257
) -> None:
249258
"""
250259
Decode the log probabilities and update the state
251260
Args:
252-
frames: (list[Frame]) List of frames to transcribe.
261+
requests: (list[Request]) List of requests (frames or feature buffers) to transcribe.
253262
log_probs: (Tensor) Log probabilities.
254263
tail_log_probs: (Tensor | None) Tail log probabilities.
255264
ready_state_ids: (set) Set of ready state IDs.
256265
"""
257266

258-
for idx, frame in enumerate(frames):
259-
state = self.get_state(frame.stream_id)
260-
eou_detected = self.run_greedy_decoder(state, frame, log_probs[idx])
267+
for idx, request in enumerate(requests):
268+
state = self.get_state(request.stream_id)
269+
eou_detected = self.run_greedy_decoder(state, request, log_probs[idx])
261270

262271
if eou_detected:
263272
self.bpe_decoder.decode_bpe_tokens(state)
264273
state.cleanup_after_eou()
265-
ready_state_ids.add(frame.stream_id)
274+
ready_state_ids.add(request.stream_id)
266275

267276
if tail_log_probs is not None:
268277
last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id
@@ -273,15 +282,15 @@ def decode_log_probs(
273282

274283
def cache_aware_transcribe_step(
275284
self,
276-
frames: list[Frame],
285+
requests: list[Request],
277286
buffered_features: list[Tensor],
278287
right_paddings: list[int] | None,
279288
ready_state_ids: set,
280289
keep_all_outputs: bool = False,
281290
) -> None:
282291
"""
283292
Cache Aware Transcribe Step
284-
It receives a list of frames and features and do the following:
293+
It receives a list of requests (Frame or FeatureBuffer) and features and does the following:
285294
286295
1. Preprocess the features by stacking them and computing the lengths
287296
2. Get the context and mapping from the context manager for cache aware streaming
@@ -290,16 +299,16 @@ def cache_aware_transcribe_step(
290299
5. Decode the log probabilities and update the state
291300
292301
Args:
293-
frames: (list[Frame]) List of frames to transcribe.
302+
requests: (list[Request]) List of requests (frames or feature buffers) to transcribe.
294303
buffered_features: (list[Tensor]) List of buffered features.
295304
right_paddings: (list[int] | None) List of right paddings.
296305
ready_state_ids: (set) Set of ready state IDs.
297306
keep_all_outputs: (bool) Whether to keep all outputs or not.
298307
"""
299308
feature_buffers, feature_buffer_lens = self.preprocess(buffered_features, right_paddings)
300309

301-
stream_ids = [frame.stream_id for frame in frames]
302-
eos_flags = [frame.is_last for frame in frames]
310+
stream_ids = [request.stream_id for request in requests]
311+
eos_flags = [request.is_last for request in requests]
303312
context, mapping = self.context_manager.get_context(stream_ids)
304313

305314
drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded
@@ -318,7 +327,7 @@ def cache_aware_transcribe_step(
318327
log_probs = normalize_log_probs(log_probs)
319328
self.context_manager.update_cache(stream_ids, new_context, mapping)
320329
self.context_manager.reset_slots(stream_ids, eos_flags)
321-
self.decode_log_probs(frames, log_probs, tail_log_probs, ready_state_ids)
330+
self.decode_log_probs(requests, log_probs, tail_log_probs, ready_state_ids)
322331

323332
def transcribe_step_for_frames(self, frames: list[Frame]) -> None:
324333
"""
@@ -362,8 +371,46 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None:
362371
self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern)
363372

364373
def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None:
365-
"""Transcribe a step for feature buffers"""
366-
raise NotImplementedError("Feature buffer type is not supported for cache aware streaming.")
374+
"""
375+
Transcribes the feature buffers in a streaming manner.
376+
After detecting EOU, it updates the state and runs the text processor.
377+
If there are multiple streams, it waits until all states are ready to run the text processor.
378+
Args:
379+
fbuffers: (list[FeatureBuffer]) List of feature buffers to transcribe.
380+
"""
381+
ready_state_ids = set()
382+
383+
final_fbuffers, final_features = [], []
384+
nonfinal_fbuffers, nonfinal_features = [], []
385+
final_right_paddings = []
386+
387+
for fbuffer in fbuffers:
388+
feature = fbuffer.features
389+
right_padding = max(0, self.expected_feature_buffer_len - fbuffer.valid_size)
390+
391+
if fbuffer.is_last:
392+
final_fbuffers.append(fbuffer)
393+
final_features.append(feature)
394+
final_right_paddings.append(right_padding)
395+
else:
396+
nonfinal_fbuffers.append(fbuffer)
397+
nonfinal_features.append(feature)
398+
399+
if len(nonfinal_fbuffers) > 0:
400+
self.cache_aware_transcribe_step(
401+
nonfinal_fbuffers, nonfinal_features, None, ready_state_ids, keep_all_outputs=False
402+
)
403+
404+
if len(final_fbuffers) > 0:
405+
self.cache_aware_transcribe_step(
406+
final_fbuffers, final_features, final_right_paddings, ready_state_ids, keep_all_outputs=True
407+
)
408+
409+
if len(ready_state_ids) > 0:
410+
self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids])
411+
ready_state_ids.clear()
412+
413+
self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern)
367414

368415
def get_request_generator(self) -> ContinuousBatchedRequestStreamer:
369416
"""
@@ -377,9 +424,9 @@ def get_request_generator(self) -> ContinuousBatchedRequestStreamer:
377424
sample_rate=self.sample_rate,
378425
batch_size=self.batch_size,
379426
request_type=self.request_type,
380-
preprocessor=None,
381-
buffer_size_in_secs=None,
382-
device=None,
427+
preprocessor=self.preprocessor,
428+
buffer_size_in_secs=self.buffer_size_in_secs,
429+
device=self.device,
383430
pad_last_frame=True,
384431
)
385432
return request_generator

0 commit comments

Comments
 (0)