diff --git a/test/test_video_reader.py b/test/test_video_reader.py
index d9326138397..9818b6fc900 100644
--- a/test/test_video_reader.py
+++ b/test/test_video_reader.py
@@ -1238,6 +1238,56 @@ def test_read_video_from_memory_scripted(self):
             )
             # FUTURE: check value of video / audio frames
 
+    def test_audio_video_sync(self):
+        """Test if audio/video are synchronised with pyav output.
+
+        Every test video that has an audio stream is decoded with PyAV,
+        dropping the first ``cutoff`` frames of each stream, and the result
+        is compared with ``io.read_video`` started at the pts of the first
+        video frame that was kept.
+        """
+        for test_video, config in test_videos.items():
+            full_path = os.path.join(VIDEO_DIR, test_video)
+            start_pts_val, cutoff = 0, 1
+            # Open the container in a ``with`` block so that it is closed even
+            # when the video is skipped or an assertion fails.
+            with av.open(full_path) as container:
+                if not container.streams.audio:
+                    # Skip if no audio stream
+                    continue
+                if container.streams.video:
+                    video = container.streams.video[0]
+                    arr = []
+                    for index, frame in enumerate(container.decode(video)):
+                        if index == cutoff:
+                            start_pts_val = frame.pts
+                        if index >= cutoff:
+                            arr.append(frame.to_rgb().to_ndarray())
+                    visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
+                    self.assertAlmostEqual(
+                        config.video_fps, info['video_fps'], delta=0.0001
+                    )
+                    arr = torch.Tensor(arr)
+                    # Frame values are compared only when both decoders
+                    # returned tensors of the same shape.
+                    if arr.shape == visual.shape:
+                        self.assertGreaterEqual(
+                            torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99)
+
+            # Re-open the file to decode the audio stream. The stream is known
+            # to exist here because videos without audio were skipped above.
+            with av.open(full_path) as container:
+                audio = container.streams.audio[0]
+                arr = []
+                for index, frame in enumerate(container.decode(audio)):
+                    if index >= cutoff:
+                        arr.append(frame.to_ndarray())
+                _, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
+                arr = torch.as_tensor(np.concatenate(arr, axis=1))
+                if arr.shape == audio.shape:
+                    self.assertGreaterEqual(
+                        torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index 4cc2b60c706..a34b023bc6c 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -471,6 +471,20 @@ def _probe_video_from_memory(video_data):
     return info
 
 
+def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
+    """Convert a pts range to seconds.
+
+    If ``pts_unit`` is ``'pts'``, both bounds are multiplied by ``time_base``
+    and returned as floats together with the unit ``'sec'``. For any
+    other ``pts_unit`` the three values are returned unchanged.
+    """
+    if pts_unit == 'pts':
+        start_pts = float(start_pts * time_base)
+        end_pts = float(end_pts * time_base)
+        pts_unit = 'sec'
+    return start_pts, end_pts, pts_unit
+
+
 def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
     if end_pts is None:
         end_pts = float("inf")
@@ -485,32 +499,44 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
 
     has_video = info.has_video
     has_audio = info.has_audio
+    video_pts_range = (0, -1)
+    video_timebase = default_timebase
+    audio_pts_range = (0, -1)
+    audio_timebase = default_timebase
+    time_base = default_timebase
+
+    if has_video:
+        video_timebase = Fraction(
+            info.video_timebase.numerator, info.video_timebase.denominator
+        )
+        time_base = video_timebase
+
+    if has_audio:
+        audio_timebase = Fraction(
+            info.audio_timebase.numerator, info.audio_timebase.denominator
+        )
+        time_base = time_base if time_base else audio_timebase
+
+    # Pts values are interpreted in the video time base when one is set;
+    # the audio time base is used only if time_base is still falsy.
+    start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
+        start_pts, end_pts, pts_unit, time_base)
 
     def get_pts(time_base):
-        start_offset = start_pts
-        end_offset = end_pts
+        start_offset = start_pts_sec
+        end_offset = end_pts_sec
         if pts_unit == "sec":
-            start_offset = int(math.floor(start_pts * (1 / time_base)))
+            start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
             if end_offset != float("inf"):
-                end_offset = int(math.ceil(end_pts * (1 / time_base)))
+                end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
         if end_offset == float("inf"):
             end_offset = -1
         return start_offset, end_offset
 
-    video_pts_range = (0, -1)
-    video_timebase = default_timebase
     if has_video:
-        video_timebase = Fraction(
-            info.video_timebase.numerator, info.video_timebase.denominator
-        )
         video_pts_range = get_pts(video_timebase)
 
-    audio_pts_range = (0, -1)
-    audio_timebase = default_timebase
     if has_audio:
-        audio_timebase = Fraction(
-            info.audio_timebase.numerator, info.audio_timebase.denominator
-        )
         audio_pts_range = get_pts(audio_timebase)
 
     vframes, aframes, info = _read_video_from_file(
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index 22cad38d10b..e16e8906d97 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -278,11 +278,20 @@ def read_video(
 
     try:
         with av.open(filename, metadata_errors="ignore") as container:
+            time_base = _video_opt.default_timebase
+            if container.streams.video:
+                time_base = container.streams.video[0].time_base
+            elif container.streams.audio:
+                time_base = container.streams.audio[0].time_base
+            # The video stream's time base takes precedence; the audio
+            # stream's time base is used only when there is no video stream.
+            start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
+                start_pts, end_pts, pts_unit, time_base)
             if container.streams.video:
                 video_frames = _read_from_stream(
                     container,
-                    start_pts,
-                    end_pts,
+                    start_pts_sec,
+                    end_pts_sec,
                     pts_unit,
                     container.streams.video[0],
                     {"video": 0},
@@ -295,8 +304,8 @@ def read_video(
             if container.streams.audio:
                 audio_frames = _read_from_stream(
                     container,
-                    start_pts,
-                    end_pts,
+                    start_pts_sec,
+                    end_pts_sec,
                     pts_unit,
                     container.streams.audio[0],
                     {"audio": 0},