Skip to content

Commit cabd3e9

Browse files
Update ssl_dataset.py (#14086)
Signed-off-by: Mahmoud Ashraf <hassouna97.ma@gmail.com>
1 parent 6cadb8e commit cabd3e9

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

nemo/collections/asr/data/ssl_dataset.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,23 @@ def __getitem__(self, cuts):
462462
audios, audio_lens, cuts = self.load_audio(cuts)
463463
sampled_noises = [sample_noise(self.noise_data, cut.sampling_rate, cut.num_samples) for cut in cuts]
464464

465-
items = [
466-
AudioNoiseItem(
467-
sample_id=str(cuts[i].id),
468-
audio=audios[i],
469-
audio_len=audio_lens[i],
470-
noise=sampled_noises[i][0],
471-
noise_len=sampled_noises[i][1],
472-
noisy_audio=audios[i] + sampled_noises[i][0],
473-
noisy_audio_len=audio_lens[i],
474-
)
475-
for i in range(len(cuts))
476-
]
477-
return _audio_noise_collate_fn(items, self.batch_augmentor)
465+
sampled_noises, sampled_noises_lens = zip(*sampled_noises)
466+
sampled_noises = torch.stack(sampled_noises).float()
467+
sampled_noises_lens = torch.tensor(sampled_noises_lens).long()
468+
469+
output = AudioNoiseBatch(
470+
audio=audios,
471+
audio_len=audio_lens,
472+
noise=sampled_noises,
473+
noise_len=sampled_noises_lens,
474+
noisy_audio=audios + sampled_noises,
475+
noisy_audio_len=audio_lens,
476+
)
477+
478+
if self.batch_augmentor is not None:
479+
output = self.batch_augmentor(output)
480+
481+
return output
478482

479483

480484
def get_audio_noise_dataset(

0 commit comments

Comments
 (0)