@@ -102,6 +102,10 @@ def __init__(self, config: Dict[str, Any]):
102102 self .augmentor_cfg = config .get ('augmentor' , None )
103103 self .sample_rate = config ['sample_rate' ]
104104
105+ self .pad_min_duration = config .get ('pad_min_duration' , 1.0 )
106+ self .pad_direction = config .get ('pad_direction' , 'both' )
107+ self .pad_min_samples = int (self .pad_min_duration * self .sample_rate )
108+
105109 if self .augmentor_cfg is not None :
106110 self .augmentor = process_augmentations (self .augmentor_cfg , global_rank = 0 , world_size = 1 )
107111 else :
@@ -118,6 +122,25 @@ def __getitem__(self, index):
def __len__(self):
    """Return the item count stored in ``self.length``."""
    return self.length
120124
def _pad_audio(self, samples: torch.Tensor) -> torch.Tensor:
    """Zero-pad ``samples`` along dim 0 up to ``self.pad_min_samples``.

    The current length is read from ``samples.shape[0]``, so the padding is
    applied to that same axis; any trailing dimensions are left untouched.
    Per the original author's note, this is intended to mirror the
    minimum-duration padding of the Lhotse dataloader.

    Args:
        samples: Audio tensor whose first dimension is the one being padded.

    Returns:
        ``samples`` itself when it already has at least
        ``self.pad_min_samples`` entries along dim 0. Otherwise a new tensor
        padded with zeros according to ``self.pad_direction``: ``'both'``
        splits the padding (the extra sample of an odd amount goes to the
        end), ``'left'`` pads only the start, and any other value pads only
        the end.
    """
    missing = self.pad_min_samples - samples.shape[0]
    if missing <= 0:
        return samples

    if self.pad_direction == 'both':
        pad_left = missing // 2
        pad_right = missing - pad_left
    elif self.pad_direction == 'left':
        pad_left, pad_right = missing, 0
    else:  # 'right', and the fallback for any unrecognized value
        pad_left, pad_right = 0, missing

    # F.pad consumes (before, after) pairs starting from the LAST dimension,
    # so prefix one (0, 0) pair per trailing dimension to reach dim 0. For a
    # 1-D tensor this is just (pad_left, pad_right).
    pad_spec = (0, 0) * (samples.dim() - 1) + (pad_left, pad_right)
    return torch.nn.functional.pad(samples, pad_spec, mode='constant', value=0.0)
143+
121144 def get_item (self , index ):
122145 samples = self .audio_tensors [index ]
123146
@@ -136,7 +159,7 @@ def get_item(self, index):
136159 samples = self .augmentor .perturb (segment )
137160 samples = torch .tensor (samples .samples , dtype = original_dtype )
138161
139- # Calculate seq length
162+ samples = self . _pad_audio ( samples )
140163 seq_len = torch .tensor (samples .shape [0 ], dtype = torch .long )
141164
142165 # Typically NeMo ASR models expect the mini-batch to be a 4-tuple of (audio, audio_len, text, text_len).
@@ -538,6 +561,8 @@ def _transcribe_input_tensor_processing(
538561 'num_workers' : get_value_from_transcription_config (trcfg , 'num_workers' , 0 ),
539562 'channel_selector' : get_value_from_transcription_config (trcfg , 'channel_selector' , None ),
540563 'sample_rate' : sample_rate ,
564+ 'pad_min_duration' : get_value_from_transcription_config (trcfg , 'pad_min_duration' , 1.0 ),
565+ 'pad_direction' : get_value_from_transcription_config (trcfg , 'pad_direction' , 'both' ),
541566 }
542567
543568 augmentor = get_value_from_transcription_config (trcfg , 'augmentor' , None )
0 commit comments