31
31
VIDEO_DIR = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "videos" )
32
32
33
33
CheckerConfig = [
34
+ "duration" ,
34
35
"video_fps" ,
35
36
"audio_sample_rate" ,
36
37
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
44
45
)
45
46
46
47
all_check_config = GroundTruth (
48
+ duration = 0 ,
47
49
video_fps = 0 ,
48
50
audio_sample_rate = 0 ,
49
51
check_aframes = True ,
52
54
53
55
test_videos = {
54
56
"RATRACE_wave_f_nm_np1_fr_goo_37.avi" : GroundTruth (
57
+ duration = 2.0 ,
55
58
video_fps = 30.0 ,
56
59
audio_sample_rate = None ,
57
60
check_aframes = True ,
58
61
check_aframe_pts = True ,
59
62
),
60
63
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi" : GroundTruth (
64
+ duration = 2.0 ,
61
65
video_fps = 30.0 ,
62
66
audio_sample_rate = None ,
63
67
check_aframes = True ,
64
68
check_aframe_pts = True ,
65
69
),
66
70
"TrumanShow_wave_f_nm_np1_fr_med_26.avi" : GroundTruth (
71
+ duration = 2.0 ,
67
72
video_fps = 30.0 ,
68
73
audio_sample_rate = None ,
69
74
check_aframes = True ,
70
75
check_aframe_pts = True ,
71
76
),
72
77
"v_SoccerJuggling_g23_c01.avi" : GroundTruth (
78
+ duration = 8.0 ,
73
79
video_fps = 29.97 ,
74
80
audio_sample_rate = None ,
75
81
check_aframes = True ,
76
82
check_aframe_pts = True ,
77
83
),
78
84
"v_SoccerJuggling_g24_c01.avi" : GroundTruth (
85
+ duration = 8.0 ,
79
86
video_fps = 29.97 ,
80
87
audio_sample_rate = None ,
81
88
check_aframes = True ,
82
89
check_aframe_pts = True ,
83
90
),
84
91
"R6llTwEh07w.mp4" : GroundTruth (
92
+ duration = 10.0 ,
85
93
video_fps = 30.0 ,
86
94
audio_sample_rate = 44100 ,
87
95
# PyAv miss one audio frame at the beginning (pts=0)
88
96
check_aframes = False ,
89
97
check_aframe_pts = False ,
90
98
),
91
99
"SOX5yA1l24A.mp4" : GroundTruth (
100
+ duration = 11.0 ,
92
101
video_fps = 29.97 ,
93
102
audio_sample_rate = 48000 ,
94
103
# PyAv miss one audio frame at the beginning (pts=0)
95
104
check_aframes = False ,
96
105
check_aframe_pts = False ,
97
106
),
98
107
"WUzgd7C1pWA.mp4" : GroundTruth (
108
+ duration = 11.0 ,
99
109
video_fps = 29.97 ,
100
110
audio_sample_rate = 48000 ,
101
111
# PyAv miss one audio frame at the beginning (pts=0)
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
272
282
def check_separate_decoding_result (self , tv_result , config ):
273
283
"""check the decoding results from TorchVision decoder
274
284
"""
275
- vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate = (
276
- tv_result
285
+ vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
286
+ atimebase , asample_rate , aduration = tv_result
287
+
288
+ video_duration = vduration .item () * Fraction (
289
+ vtimebase [0 ].item (), vtimebase [1 ].item ()
277
290
)
291
+ self .assertAlmostEqual (video_duration , config .duration , delta = 0.5 )
278
292
279
293
self .assertAlmostEqual (vfps .item (), config .video_fps , delta = 0.5 )
280
294
if asample_rate .numel () > 0 :
281
295
self .assertEqual (asample_rate .item (), config .audio_sample_rate )
296
+ audio_duration = aduration .item () * Fraction (
297
+ atimebase [0 ].item (), atimebase [1 ].item ()
298
+ )
299
+ self .assertAlmostEqual (audio_duration , config .duration , delta = 0.5 )
300
+
282
301
# check if pts of video frames are sorted in ascending order
283
302
for i in range (len (vframe_pts ) - 1 ):
284
303
self .assertEqual (vframe_pts [i ] < vframe_pts [i + 1 ], True )
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
288
307
for i in range (len (aframe_pts ) - 1 ):
289
308
self .assertEqual (aframe_pts [i ] < aframe_pts [i + 1 ], True )
290
309
310
+ def check_probe_result (self , result , config ):
311
+ vtimebase , vfps , vduration , atimebase , asample_rate , aduration = result
312
+ video_duration = vduration .item () * Fraction (
313
+ vtimebase [0 ].item (), vtimebase [1 ].item ()
314
+ )
315
+ self .assertAlmostEqual (video_duration , config .duration , delta = 0.5 )
316
+ self .assertAlmostEqual (vfps .item (), config .video_fps , delta = 0.5 )
317
+ if asample_rate .numel () > 0 :
318
+ self .assertEqual (asample_rate .item (), config .audio_sample_rate )
319
+ audio_duration = aduration .item () * Fraction (
320
+ atimebase [0 ].item (), atimebase [1 ].item ()
321
+ )
322
+ self .assertAlmostEqual (audio_duration , config .duration , delta = 0.5 )
323
+
291
324
def compare_decoding_result (self , tv_result , ref_result , config = all_check_config ):
292
325
"""
293
326
Compare decoding results from two sources.
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
297
330
decoder or TorchVision decoder with getPtsOnly = 1
298
331
config: config of decoding results checker
299
332
"""
300
- vframes , vframe_pts , vtimebase , _vfps , aframes , aframe_pts , atimebase , _asample_rate = (
301
- tv_result
302
- )
333
+ vframes , vframe_pts , vtimebase , _vfps , _vduration , aframes , aframe_pts , \
334
+ atimebase , _asample_rate , _aduration = tv_result
303
335
if isinstance (ref_result , list ):
304
336
# the ref_result is from new video_reader decoder
305
337
ref_result = DecoderResult (
306
338
vframes = ref_result [0 ],
307
339
vframe_pts = ref_result [1 ],
308
340
vtimebase = ref_result [2 ],
309
- aframes = ref_result [4 ],
310
- aframe_pts = ref_result [5 ],
311
- atimebase = ref_result [6 ],
341
+ aframes = ref_result [5 ],
342
+ aframe_pts = ref_result [6 ],
343
+ atimebase = ref_result [7 ],
312
344
)
313
345
314
346
if vframes .numel () > 0 and ref_result .vframes .numel () > 0 :
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
351
383
audio_start_pts , audio_end_pts = 0 , - 1
352
384
audio_timebase_num , audio_timebase_den = 0 , 1
353
385
354
- for i in range (num_iter ):
355
- for test_video , config in test_videos .items ():
386
+ for _i in range (num_iter ):
387
+ for test_video , _config in test_videos .items ():
356
388
full_path = os .path .join (VIDEO_DIR , test_video )
357
389
358
390
# pass 1: decode all frames using new decoder
359
- _ = torch .ops .video_reader .read_video_from_file (
391
+ torch .ops .video_reader .read_video_from_file (
360
392
full_path ,
361
393
seek_frame_margin ,
362
394
0 , # getPtsOnly
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
460
492
audio_timebase_den ,
461
493
)
462
494
463
- vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate = (
464
- tv_result
465
- )
495
+ vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
496
+ atimebase , asample_rate , aduration = tv_result
466
497
467
498
self .assertEqual (vframes .numel () > 0 , readVideoStream )
468
499
self .assertEqual (vframe_pts .numel () > 0 , readVideoStream )
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
489
520
audio_start_pts , audio_end_pts = 0 , - 1
490
521
audio_timebase_num , audio_timebase_den = 0 , 1
491
522
492
- for test_video , config in test_videos .items ():
523
+ for test_video , _config in test_videos .items ():
493
524
full_path = os .path .join (VIDEO_DIR , test_video )
494
525
495
526
tv_result = torch .ops .video_reader .read_video_from_file (
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
528
559
audio_start_pts , audio_end_pts = 0 , - 1
529
560
audio_timebase_num , audio_timebase_den = 0 , 1
530
561
531
- for test_video , config in test_videos .items ():
562
+ for test_video , _config in test_videos .items ():
532
563
full_path = os .path .join (VIDEO_DIR , test_video )
533
564
534
565
tv_result = torch .ops .video_reader .read_video_from_file (
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
567
598
audio_start_pts , audio_end_pts = 0 , - 1
568
599
audio_timebase_num , audio_timebase_den = 0 , 1
569
600
570
- for test_video , config in test_videos .items ():
601
+ for test_video , _config in test_videos .items ():
571
602
full_path = os .path .join (VIDEO_DIR , test_video )
572
603
573
604
tv_result = torch .ops .video_reader .read_video_from_file (
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
606
637
audio_start_pts , audio_end_pts = 0 , - 1
607
638
audio_timebase_num , audio_timebase_den = 0 , 1
608
639
609
- for test_video , config in test_videos .items ():
640
+ for test_video , _config in test_videos .items ():
610
641
full_path = os .path .join (VIDEO_DIR , test_video )
611
642
612
643
tv_result = torch .ops .video_reader .read_video_from_file (
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
651
682
audio_start_pts , audio_end_pts = 0 , - 1
652
683
audio_timebase_num , audio_timebase_den = 0 , 1
653
684
654
- for test_video , config in test_videos .items ():
685
+ for test_video , _config in test_videos .items ():
655
686
full_path = os .path .join (VIDEO_DIR , test_video )
656
687
657
688
tv_result = torch .ops .video_reader .read_video_from_file (
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
674
705
audio_timebase_num ,
675
706
audio_timebase_den ,
676
707
)
677
- vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , a_sample_rate = (
678
- tv_result
679
- )
708
+ vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
709
+ atimebase , asample_rate , aduration = tv_result
680
710
if aframes .numel () > 0 :
681
- self .assertEqual (samples , a_sample_rate .item ())
711
+ self .assertEqual (samples , asample_rate .item ())
682
712
self .assertEqual (1 , aframes .size (1 ))
683
713
# when audio stream is found
684
714
duration = float (aframe_pts [- 1 ]) * float (atimebase [0 ]) / float (atimebase [1 ])
685
715
self .assertAlmostEqual (
686
716
aframes .size (0 ),
687
- int (duration * a_sample_rate .item ()),
688
- delta = 0.1 * a_sample_rate .item (),
717
+ int (duration * asample_rate .item ()),
718
+ delta = 0.1 * asample_rate .item (),
689
719
)
690
720
691
721
def test_compare_read_video_from_memory_and_file (self ):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
859
889
)
860
890
861
891
self .assertEqual (tv_result_pts_only [0 ].numel (), 0 )
862
- self .assertEqual (tv_result_pts_only [4 ].numel (), 0 )
892
+ self .assertEqual (tv_result_pts_only [5 ].numel (), 0 )
863
893
self .compare_decoding_result (tv_result , tv_result_pts_only )
864
894
865
895
def test_read_video_in_range_from_memory (self ):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
899
929
audio_timebase_num ,
900
930
audio_timebase_den ,
901
931
)
902
- vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate = (
903
- tv_result
904
- )
932
+ vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
933
+ atimebase , asample_rate , aduration = tv_result
905
934
self .assertAlmostEqual (config .video_fps , vfps .item (), delta = 0.01 )
906
935
907
936
for num_frames in [4 , 8 , 16 , 32 , 64 , 128 ]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
997
1026
# and PyAv
998
1027
self .compare_decoding_result (tv_result , pyav_result , config )
999
1028
1029
+ def test_probe_video_from_file (self ):
1030
+ """
1031
+ Test the case when decoder probes a video file
1032
+ """
1033
+ for test_video , config in test_videos .items ():
1034
+ full_path = os .path .join (VIDEO_DIR , test_video )
1035
+ probe_result = torch .ops .video_reader .probe_video_from_file (full_path )
1036
+ self .check_probe_result (probe_result , config )
1037
+
1038
+ def test_probe_video_from_memory (self ):
1039
+ """
1040
+ Test the case when decoder probes a video in memory
1041
+ """
1042
+ for test_video , config in test_videos .items ():
1043
+ full_path , video_tensor = _get_video_tensor (VIDEO_DIR , test_video )
1044
+ probe_result = torch .ops .video_reader .probe_video_from_memory (video_tensor )
1045
+ self .check_probe_result (probe_result , config )
1046
+
1000
1047
1001
1048
if __name__ == '__main__' :
1002
1049
unittest .main ()
0 commit comments