Skip to content

Commit c6df9b4

Browse files
authored
improve canary performance on short audio (#15317)
* improve canary performance on short audio
  Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
* Apply isort and black reformatting
  Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
* revert non canary models
  Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
---------
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>
1 parent 119593e commit c6df9b4

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

nemo/collections/asr/parts/mixins/transcription.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def __init__(self, config: Dict[str, Any]):
102102
self.augmentor_cfg = config.get('augmentor', None)
103103
self.sample_rate = config['sample_rate']
104104

105+
self.pad_min_duration = config.get('pad_min_duration', 1.0)
106+
self.pad_direction = config.get('pad_direction', 'both')
107+
self.pad_min_samples = int(self.pad_min_duration * self.sample_rate)
108+
105109
if self.augmentor_cfg is not None:
106110
self.augmentor = process_augmentations(self.augmentor_cfg, global_rank=0, world_size=1)
107111
else:
@@ -118,6 +122,25 @@ def __getitem__(self, index):
118122
def __len__(self):
119123
return self.length
120124

125+
def _pad_audio(self, samples: torch.Tensor) -> torch.Tensor:
    """Zero-pad audio along its first (time) dimension to a minimum length.

    Intended to mirror the minimum-duration padding of the Lhotse dataloader
    so that very short inputs are treated the same way on this code path.

    Args:
        samples: Audio tensor whose first dimension is the sample (time) axis.

    Returns:
        ``samples`` itself when it already holds at least
        ``self.pad_min_samples`` samples; otherwise a new tensor padded with
        zeros to exactly ``self.pad_min_samples`` along dim 0. Where the
        zeros go is controlled by ``self.pad_direction`` (``'both'``,
        ``'left'`` or ``'right'``); any other value is treated as ``'right'``.
    """
    missing = self.pad_min_samples - samples.shape[0]
    if missing <= 0:
        return samples

    if self.pad_direction == 'both':
        # Split as evenly as possible; an odd leftover sample goes to the right.
        pad_left = missing // 2
        pad_right = missing - pad_left
    elif self.pad_direction == 'left':
        pad_left, pad_right = missing, 0
    else:
        # 'right', and any unrecognised value, pads at the end only.
        pad_left, pad_right = 0, missing

    # F.pad reads (left, right) pairs starting from the LAST dimension, so a
    # bare 2-tuple would pad the wrong axis for tensors with more than one
    # dimension. Leave every trailing dimension untouched and pad only dim 0,
    # the axis the length check above was made on.
    pad_spec = (0, 0) * (samples.dim() - 1) + (pad_left, pad_right)
    return torch.nn.functional.pad(samples, pad_spec, mode='constant', value=0.0)
143+
121144
def get_item(self, index):
122145
samples = self.audio_tensors[index]
123146

@@ -136,7 +159,7 @@ def get_item(self, index):
136159
samples = self.augmentor.perturb(segment)
137160
samples = torch.tensor(samples.samples, dtype=original_dtype)
138161

139-
# Calculate seq length
162+
samples = self._pad_audio(samples)
140163
seq_len = torch.tensor(samples.shape[0], dtype=torch.long)
141164

142165
# Typically NeMo ASR models expect the mini-batch to be a 4-tuple of (audio, audio_len, text, text_len).
@@ -538,6 +561,8 @@ def _transcribe_input_tensor_processing(
538561
'num_workers': get_value_from_transcription_config(trcfg, 'num_workers', 0),
539562
'channel_selector': get_value_from_transcription_config(trcfg, 'channel_selector', None),
540563
'sample_rate': sample_rate,
564+
'pad_min_duration': get_value_from_transcription_config(trcfg, 'pad_min_duration', 1.0),
565+
'pad_direction': get_value_from_transcription_config(trcfg, 'pad_direction', 'both'),
541566
}
542567

543568
augmentor = get_value_from_transcription_config(trcfg, 'augmentor', None)

0 commit comments

Comments
 (0)