@@ -244,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
244
244
video_frames = torch .empty (0 )
245
245
frames = []
246
246
video_pts = []
247
- for frame in itertools .takewhile (lambda x : x [' pts' ] <= e , video_object ):
248
- if frame [' pts' ] < s :
247
+ for frame in itertools .takewhile (lambda x : x [" pts" ] <= e , video_object ):
248
+ if frame [" pts" ] < s :
249
249
continue
250
- frames .append (frame [' data' ])
251
- video_pts .append (frame [' pts' ])
250
+ frames .append (frame [" data" ])
251
+ video_pts .append (frame [" pts" ])
252
252
if len (frames ) > 0 :
253
253
video_frames = torch .stack (frames , 0 )
254
254
@@ -257,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
257
257
audio_frames = torch .empty (0 )
258
258
frames = []
259
259
audio_pts = []
260
- for frame in itertools .takewhile (lambda x : x [' pts' ] <= e , video_object ):
261
- if frame [' pts' ] < s :
260
+ for frame in itertools .takewhile (lambda x : x [" pts" ] <= e , video_object ):
261
+ if frame [" pts" ] < s :
262
262
continue
263
- frames .append (frame [' data' ])
264
- audio_pts .append (frame [' pts' ])
263
+ frames .append (frame [" data" ])
264
+ audio_pts .append (frame [" pts" ])
265
265
if len (frames ) > 0 :
266
266
audio_frames = torch .stack (frames , 0 )
267
267
@@ -294,7 +294,7 @@ def test_read_video_tensor(self):
294
294
reader = VideoReader (full_path , "video" )
295
295
frames = []
296
296
for frame in reader :
297
- frames .append (frame [' data' ])
297
+ frames .append (frame [" data" ])
298
298
new_api = torch .stack (frames , 0 )
299
299
self .assertEqual (tv_result .size (), new_api .size ())
300
300
@@ -402,6 +402,45 @@ def test_video_reading_fn(self):
402
402
).item ()
403
403
self .assertEqual (is_same , True )
404
404
405
+ @unittest .skipIf (av is None , "PyAV unavailable" )
406
+ def test_keyframe_reading (self ):
407
+ for test_video , config in test_videos .items ():
408
+ full_path = os .path .join (VIDEO_DIR , test_video )
409
+
410
+ av_reader = av .open (full_path )
411
+ # reduce streams to only keyframes
412
+ av_stream = av_reader .streams .video [0 ]
413
+ av_stream .codec_context .skip_frame = "NONKEY"
414
+
415
+ av_keyframes = []
416
+ vr_keyframes = []
417
+ if av_reader .streams .video :
418
+
419
+ # get all keyframes using pyav. Then, seek randomly into video reader
420
+ # and assert that all the returned values are in AV_KEYFRAMES
421
+
422
+ for av_frame in av_reader .decode (av_stream ):
423
+ av_keyframes .append (float (av_frame .pts * av_frame .time_base ))
424
+
425
+ if len (av_keyframes ) > 1 :
426
+ video_reader = VideoReader (full_path , "video" )
427
+ for i in range (1 , len (av_keyframes )):
428
+ seek_val = (av_keyframes [i ] + av_keyframes [i - 1 ]) / 2
429
+ data = next (video_reader .seek (seek_val , True ))
430
+ vr_keyframes .append (data ["pts" ])
431
+
432
+ data = next (video_reader .seek (config .duration , True ))
433
+ vr_keyframes .append (data ["pts" ])
434
+
435
+ self .assertTrue (len (av_keyframes ) == len (vr_keyframes ))
436
+ # NOTE: this video gets different keyframe with different
437
+ # loaders (0.333 pyav, 0.666 for us)
438
+ if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi" :
439
+ for i in range (len (av_keyframes )):
440
+ self .assertAlmostEqual (
441
+ av_keyframes [i ], vr_keyframes [i ], delta = 0.001
442
+ )
443
+
405
444
406
445
if __name__ == "__main__" :
407
446
unittest .main ()
0 commit comments