@@ -462,19 +462,23 @@ def __getitem__(self, cuts):
462462 audios , audio_lens , cuts = self .load_audio (cuts )
463463 sampled_noises = [sample_noise (self .noise_data , cut .sampling_rate , cut .num_samples ) for cut in cuts ]
464464
465- items = [
466- AudioNoiseItem (
467- sample_id = str (cuts [i ].id ),
468- audio = audios [i ],
469- audio_len = audio_lens [i ],
470- noise = sampled_noises [i ][0 ],
471- noise_len = sampled_noises [i ][1 ],
472- noisy_audio = audios [i ] + sampled_noises [i ][0 ],
473- noisy_audio_len = audio_lens [i ],
474- )
475- for i in range (len (cuts ))
476- ]
477- return _audio_noise_collate_fn (items , self .batch_augmentor )
465+ sampled_noises , sampled_noises_lens = zip (* sampled_noises )
466+ sampled_noises = torch .stack (sampled_noises ).float ()
467+ sampled_noises_lens = torch .tensor (sampled_noises_lens ).long ()
468+
469+ output = AudioNoiseBatch (
470+ audio = audios ,
471+ audio_len = audio_lens ,
472+ noise = sampled_noises ,
473+ noise_len = sampled_noises_lens ,
474+ noisy_audio = audios + sampled_noises ,
475+ noisy_audio_len = audio_lens ,
476+ )
477+
478+ if self .batch_augmentor is not None :
479+ output = self .batch_augmentor (output )
480+
481+ return output
478482
479483
480484def get_audio_noise_dataset (
0 commit comments