Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/datasets/features/_torchcodec.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import numpy as np
from torchcodec.decoders import AudioDecoder as _AudioDecoder


class AudioDecoder(_AudioDecoder):
def __getitem__(self, key: str):
if key == "array":
y = self.get_all_samples().data.cpu().numpy()
return np.mean(y, axis=tuple(range(y.ndim - 1))) if y.ndim > 1 else y
if y.ndim <= 1:
return y
requested_num_channels = getattr(self, "_hf_num_channels", None)
if requested_num_channels == 1:
return y.squeeze(0)
return y
elif key == "sampling_rate":
return self.get_samples_played_in_range(0, 0).sample_rate
elif hasattr(super(), "__getitem__"):
Expand Down
1 change: 1 addition & 0 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def decode_example(
audio = AudioDecoder(
bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels
)
audio._hf_num_channels = self.num_channels
audio._hf_encoded = {"path": path, "bytes": bytes}
audio.metadata.path = path
return audio
Expand Down
14 changes: 10 additions & 4 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,17 @@ def test_backwards_compatibility(shared_datadir):
assert isinstance(decoded_example, AudioDecoder)
samples = decoded_example.get_all_samples()
assert decoded_example["sampling_rate"] == samples.sample_rate
assert decoded_example["array"].ndim == 1 # mono
assert abs(decoded_example["array"].shape[0] - samples.data.shape[1]) < 2 # can have off by one error
assert decoded_example["array"].ndim == 2
assert decoded_example["array"].shape[0] == samples.data.shape[0]
assert abs(decoded_example["array"].shape[1] - samples.data.shape[1]) < 2 # can have off by one error

decoded_example = audio.decode_example(audio.encode_example(audio_path2))
assert isinstance(decoded_example, AudioDecoder)
samples = decoded_example.get_all_samples()
assert decoded_example["sampling_rate"] == samples.sample_rate
assert decoded_example["array"].ndim == 1 # mono
assert abs(decoded_example["array"].shape[0] - samples.data.shape[1]) < 2 # can have off by one error
assert decoded_example["array"].ndim == 2
assert decoded_example["array"].shape[0] == samples.data.shape[0]
assert abs(decoded_example["array"].shape[1] - samples.data.shape[1]) < 2 # can have off by one error


@require_torchcodec
Expand Down Expand Up @@ -801,6 +803,8 @@ def test_audio_decode_example_opus_convert_to_stereo(shared_datadir):
decoded_example = audio.decode_example(audio.encode_example(audio_path))
assert isinstance(decoded_example, AudioDecoder)
samples = decoded_example.get_all_samples()
assert decoded_example["array"].ndim == 2
assert decoded_example["array"].shape[0] == 2
assert samples.sample_rate == 48000
assert samples.data.shape == (2, 48000)

Expand All @@ -815,5 +819,7 @@ def test_audio_decode_example_opus_convert_to_mono(shared_datadir):
decoded_example = audio.decode_example(audio.encode_example(audio_path))
assert isinstance(decoded_example, AudioDecoder)
samples = decoded_example.get_all_samples()
assert decoded_example["array"].ndim == 1
assert abs(decoded_example["array"].shape[0] - samples.data.shape[1]) < 2
assert samples.sample_rate == 44100
assert samples.data.shape == (1, 202311)
Loading