2929from nemo .collections .asr .inference .streaming .decoders .greedy .greedy_ctc_decoder import CTCGreedyDecoder
3030from nemo .collections .asr .inference .streaming .endpointing .greedy .greedy_ctc_endpointing import CTCGreedyEndpointing
3131from nemo .collections .asr .inference .streaming .framing .multi_stream import ContinuousBatchedRequestStreamer
32- from nemo .collections .asr .inference .streaming .framing .request import FeatureBuffer , Frame
32+ from nemo .collections .asr .inference .streaming .framing .request import FeatureBuffer , Frame , Request
3333from nemo .collections .asr .inference .streaming .framing .request_options import ASRRequestOptions
3434from nemo .collections .asr .inference .streaming .state .cache_aware_ctc_state import CacheAwareCTCStreamingState
3535from nemo .collections .asr .inference .utils .endpointing_utils import millisecond_to_frames
3636from nemo .collections .asr .inference .utils .enums import RequestType
3737from nemo .collections .asr .inference .utils .pipeline_utils import (
3838 check_existance_of_required_attributes ,
39+ drop_trailing_features ,
3940 get_confidence_utils ,
4041 normalize_log_probs ,
4142)
@@ -94,9 +95,6 @@ def init_parameters(self, cfg: DictConfig) -> None:
9495 f"Number of slots in the context manager must be >= batch_size: { self .num_slots } < { self .batch_size } "
9596 )
9697 self .request_type = RequestType .from_str (cfg .streaming .request_type )
97- if self .request_type is not RequestType .FRAME :
98- raise ValueError (f"Request type { self .request_type } is not supported for cache-aware streaming." )
99-
10098 self .word_boundary_tolerance = cfg .streaming .word_boundary_tolerance
10199 self .stop_history_eou_in_milliseconds = cfg .endpointing .stop_history_eou
102100 self .residue_tokens_at_end = cfg .endpointing .residue_tokens_at_end
@@ -142,6 +140,9 @@ def init_parameters(self, cfg: DictConfig) -> None:
142140 self .drop_left_context = left_context_size
143141 self .valid_out_len = self .tokens_per_frame
144142
143+ # Expected feature buffer length for trimming (safeguard for feature buffer inputs)
144+ self .expected_feature_buffer_len = int (self .buffer_size_in_secs / self .window_stride )
145+
145146 def init_greedy_ctc_decoder (self ) -> None :
146147 """Initialize the CTC decoder."""
147148 check_existance_of_required_attributes (self , ['vocabulary' , 'conf_func' ])
@@ -210,24 +211,28 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N
210211 (tuple[Tensor, Tensor]) Processed feature buffers and their lengths.
211212 """
212213 feature_buffers = [f_buffer .unsqueeze_ (0 ) for f_buffer in buffers ]
214+ # Trim to expected feature buffer length (safeguard for external feature buffer inputs)
215+ feature_buffers = [
216+ drop_trailing_features (f_buffer , self .expected_feature_buffer_len ) for f_buffer in feature_buffers
217+ ]
213218 feature_buffer_lens = torch .tensor ([f_buffer .shape [2 ] for f_buffer in feature_buffers ], device = self .device )
214219 if right_paddings is not None :
215220 right_paddings = torch .tensor (right_paddings , device = feature_buffer_lens .device )
216221 feature_buffer_lens = feature_buffer_lens - right_paddings
217222 feature_buffers = torch .cat (feature_buffers ).to (self .device )
218223 return feature_buffers , feature_buffer_lens
219224
220- def run_greedy_decoder (self , state : CacheAwareCTCStreamingState , frame : Frame , log_probs : Tensor ):
225+ def run_greedy_decoder (self , state : CacheAwareCTCStreamingState , request : Request , log_probs : Tensor ):
221226 """
222227 Run the greedy CTC decoder on the log_probs and update the state
223228 Args:
224229 state: (CacheAwareCTCStreamingState) The state of the stream
225- frame : (Frame ) The current frame
226- log_probs: (Tensor) The log probabilities of the current frame
230+ request : (Request ) The current request ( frame or feature buffer)
231+ log_probs: (Tensor) The log probabilities of the current request
227232 Returns:
228233 (bool) Whether EOU is detected.
229234 """
230- eou_detected = frame .is_last
235+ eou_detected = request .is_last
231236 last_token = state .label_buffer [- 1 ] if len (state .label_buffer ) > 0 else self .blank_id
232237 cur_output = self .greedy_ctc_decoder (log_probs , compute_confidence = True , previous = last_token )
233238 state .update_label_buffer (cur_output ["labels" ])
@@ -244,25 +249,29 @@ def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame, l
244249 return eou_detected
245250
246251 def decode_log_probs (
247- self , frames : list [Frame ], log_probs : Tensor , tail_log_probs : Tensor | None , ready_state_ids : set
252+ self ,
253+ requests : list [Request ],
254+ log_probs : Tensor ,
255+ tail_log_probs : Tensor | None ,
256+ ready_state_ids : set ,
248257 ) -> None :
249258 """
250259 Decode the log probabilities and update the state
251260 Args:
252- frames : (list[Frame ]) List of frames to transcribe.
261+ requests : (list[Request ]) List of requests ( frames or feature buffers) to transcribe.
253262 log_probs: (Tensor) Log probabilities.
254263 tail_log_probs: (Tensor | None) Tail log probabilities.
255264 ready_state_ids: (set) Set of ready state IDs.
256265 """
257266
258- for idx , frame in enumerate (frames ):
259- state = self .get_state (frame .stream_id )
260- eou_detected = self .run_greedy_decoder (state , frame , log_probs [idx ])
267+ for idx , request in enumerate (requests ):
268+ state = self .get_state (request .stream_id )
269+ eou_detected = self .run_greedy_decoder (state , request , log_probs [idx ])
261270
262271 if eou_detected :
263272 self .bpe_decoder .decode_bpe_tokens (state )
264273 state .cleanup_after_eou ()
265- ready_state_ids .add (frame .stream_id )
274+ ready_state_ids .add (request .stream_id )
266275
267276 if tail_log_probs is not None :
268277 last_token = state .label_buffer [- 1 ] if len (state .label_buffer ) > 0 else self .blank_id
@@ -273,15 +282,15 @@ def decode_log_probs(
273282
274283 def cache_aware_transcribe_step (
275284 self ,
276- frames : list [Frame ],
285+ requests : list [Request ],
277286 buffered_features : list [Tensor ],
278287 right_paddings : list [int ] | None ,
279288 ready_state_ids : set ,
280289 keep_all_outputs : bool = False ,
281290 ) -> None :
282291 """
283292 Cache Aware Transcribe Step
284- It receives a list of frames and features and do the following:
293+ It receives a list of requests (Frame or FeatureBuffer) and features and do the following:
285294
286295 1. Preprocess the features by stacking them and computing the lengths
287296 2. Get the context and mapping from the context manager for cache aware streaming
@@ -290,16 +299,16 @@ def cache_aware_transcribe_step(
290299 5. Decode the log probabilities and update the state
291300
292301 Args:
293- frames : (list[Frame ]) List of frames to transcribe.
302+ requests : (list[Request ]) List of requests ( frames or feature buffers) to transcribe.
294303 buffered_features: (list[Tensor]) List of buffered features.
295304 right_paddings: (list[int] | None) List of right paddings.
296305 ready_state_ids: (set) Set of ready state IDs.
297306 keep_all_outputs: (bool) Whether to keep all outputs or not.
298307 """
299308 feature_buffers , feature_buffer_lens = self .preprocess (buffered_features , right_paddings )
300309
301- stream_ids = [frame .stream_id for frame in frames ]
302- eos_flags = [frame .is_last for frame in frames ]
310+ stream_ids = [request .stream_id for request in requests ]
311+ eos_flags = [request .is_last for request in requests ]
303312 context , mapping = self .context_manager .get_context (stream_ids )
304313
305314 drop_extra_pre_encoded = 0 if not self .use_cache else self .asr_model .drop_extra_pre_encoded
@@ -318,7 +327,7 @@ def cache_aware_transcribe_step(
318327 log_probs = normalize_log_probs (log_probs )
319328 self .context_manager .update_cache (stream_ids , new_context , mapping )
320329 self .context_manager .reset_slots (stream_ids , eos_flags )
321- self .decode_log_probs (frames , log_probs , tail_log_probs , ready_state_ids )
330+ self .decode_log_probs (requests , log_probs , tail_log_probs , ready_state_ids )
322331
323332 def transcribe_step_for_frames (self , frames : list [Frame ]) -> None :
324333 """
@@ -362,8 +371,46 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None:
362371 self .update_partial_transcript (frames , self .tokenizer , self .leading_regex_pattern )
363372
def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None:
    """
    Transcribe a batch of feature buffers in a streaming manner.

    The buffers are split into two groups by their ``is_last`` flag and each
    non-empty group is passed to ``cache_aware_transcribe_step`` separately:
    non-final buffers first (no right paddings, ``keep_all_outputs=False``),
    then final buffers (with right paddings, ``keep_all_outputs=True``).
    Streams that the transcribe steps add to the ready set are then run
    through the text processor in a single call, and finally the partial
    transcripts are updated for every buffer in the batch.

    Args:
        fbuffers: (list[FeatureBuffer]) List of feature buffers to transcribe.
    """
    ready_state_ids = set()

    final_fbuffers, final_features, final_right_paddings = [], [], []
    nonfinal_fbuffers, nonfinal_features = [], []

    for fbuffer in fbuffers:
        if fbuffer.is_last:
            final_fbuffers.append(fbuffer)
            final_features.append(fbuffer.features)
            # Right padding is only consumed by the final-buffer step, so it is
            # computed only here. It is the shortfall of `valid_size` against
            # the expected buffer length, clamped at zero.
            # NOTE(review): presumably `valid_size` counts valid feature frames
            # in the buffer -- confirm against FeatureBuffer.
            final_right_paddings.append(max(0, self.expected_feature_buffer_len - fbuffer.valid_size))
        else:
            nonfinal_fbuffers.append(fbuffer)
            nonfinal_features.append(fbuffer.features)

    if nonfinal_fbuffers:
        self.cache_aware_transcribe_step(
            nonfinal_fbuffers, nonfinal_features, None, ready_state_ids, keep_all_outputs=False
        )

    if final_fbuffers:
        self.cache_aware_transcribe_step(
            final_fbuffers, final_features, final_right_paddings, ready_state_ids, keep_all_outputs=True
        )

    if ready_state_ids:
        # Run the text processor once over all streams that became ready in this step.
        self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids])
        ready_state_ids.clear()

    self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern)
367414
368415 def get_request_generator (self ) -> ContinuousBatchedRequestStreamer :
369416 """
@@ -377,9 +424,9 @@ def get_request_generator(self) -> ContinuousBatchedRequestStreamer:
377424 sample_rate = self .sample_rate ,
378425 batch_size = self .batch_size ,
379426 request_type = self .request_type ,
380- preprocessor = None ,
381- buffer_size_in_secs = None ,
382- device = None ,
427+ preprocessor = self . preprocessor ,
428+ buffer_size_in_secs = self . buffer_size_in_secs ,
429+ device = self . device ,
383430 pad_last_frame = True ,
384431 )
385432 return request_generator
0 commit comments