Skip to content

Commit ef6e129

Browse files
committed
[somewhat broken] test suite for keyframes using pyav
1 parent 3760271 commit ef6e129

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

test/test_video.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
244244
video_frames = torch.empty(0)
245245
frames = []
246246
video_pts = []
247-
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
248-
if frame['pts'] < s:
247+
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object):
248+
if frame["pts"] < s:
249249
continue
250-
frames.append(frame['data'])
251-
video_pts.append(frame['pts'])
250+
frames.append(frame["data"])
251+
video_pts.append(frame["pts"])
252252
if len(frames) > 0:
253253
video_frames = torch.stack(frames, 0)
254254

@@ -257,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
257257
audio_frames = torch.empty(0)
258258
frames = []
259259
audio_pts = []
260-
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
261-
if frame['pts'] < s:
260+
for frame in itertools.takewhile(lambda x: x["pts"] <= e, video_object):
261+
if frame["pts"] < s:
262262
continue
263-
frames.append(frame['data'])
264-
audio_pts.append(frame['pts'])
263+
frames.append(frame["data"])
264+
audio_pts.append(frame["pts"])
265265
if len(frames) > 0:
266266
audio_frames = torch.stack(frames, 0)
267267

@@ -294,7 +294,7 @@ def test_read_video_tensor(self):
294294
reader = VideoReader(full_path, "video")
295295
frames = []
296296
for frame in reader:
297-
frames.append(frame['data'])
297+
frames.append(frame["data"])
298298
new_api = torch.stack(frames, 0)
299299
self.assertEqual(tv_result.size(), new_api.size())
300300

@@ -402,6 +402,45 @@ def test_video_reading_fn(self):
402402
).item()
403403
self.assertEqual(is_same, True)
404404

405+
@unittest.skipIf(av is None, "PyAV unavailable")
406+
def test_keyframe_reading(self):
407+
for test_video, config in test_videos.items():
408+
full_path = os.path.join(VIDEO_DIR, test_video)
409+
410+
av_reader = av.open(full_path)
411+
# reduce streams to only keyframes
412+
av_stream = av_reader.streams.video[0]
413+
av_stream.codec_context.skip_frame = "NONKEY"
414+
415+
av_keyframes = []
416+
vr_keyframes = []
417+
if av_reader.streams.video:
418+
419+
# get all keyframes using pyav. Then, seek randomly into video reader
420+
# and assert that all the returned values are in AV_KEYFRAMES
421+
422+
for av_frame in av_reader.decode(av_stream):
423+
av_keyframes.append(float(av_frame.pts * av_frame.time_base))
424+
425+
if len(av_keyframes) > 1:
426+
video_reader = VideoReader(full_path, "video")
427+
for i in range(1, len(av_keyframes)):
428+
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
429+
data = next(video_reader.seek(seek_val, True))
430+
vr_keyframes.append(data["pts"])
431+
432+
data = next(video_reader.seek(config.duration, True))
433+
vr_keyframes.append(data["pts"])
434+
435+
self.assertTrue(len(av_keyframes) == len(vr_keyframes))
436+
# NOTE: this video gets different keyframe with different
437+
# loaders (0.333 pyav, 0.666 for us)
438+
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
439+
for i in range(len(av_keyframes)):
440+
self.assertAlmostEqual(
441+
av_keyframes[i], vr_keyframes[i], delta=0.001
442+
)
443+
405444

406445
if __name__ == "__main__":
407446
unittest.main()

0 commit comments

Comments
 (0)