7
7
from torchvision .datasets .video_utils import VideoClips , unfold
8
8
9
9
from common_utils import get_tmp_dir
10
+ from _assert_utils import assert_equal
10
11
11
12
12
13
@contextlib .contextmanager
@@ -40,46 +41,46 @@ def test_unfold(self):
40
41
[0 , 1 , 2 ],
41
42
[3 , 4 , 5 ],
42
43
])
43
- self . assertTrue ( r . equal ( expected ) )
44
+ assert_equal ( r , expected , check_stride = False )
44
45
45
46
r = unfold (a , 3 , 2 , 1 )
46
47
expected = torch .tensor ([
47
48
[0 , 1 , 2 ],
48
49
[2 , 3 , 4 ],
49
50
[4 , 5 , 6 ]
50
51
])
51
- self . assertTrue ( r . equal ( expected ) )
52
+ assert_equal ( r , expected , check_stride = False )
52
53
53
54
r = unfold (a , 3 , 2 , 2 )
54
55
expected = torch .tensor ([
55
56
[0 , 2 , 4 ],
56
57
[2 , 4 , 6 ],
57
58
])
58
- self . assertTrue ( r . equal ( expected ) )
59
+ assert_equal ( r , expected , check_stride = False )
59
60
60
61
@unittest .skipIf (not io .video ._av_available (), "this test requires av" )
61
62
def test_video_clips (self ):
62
63
with get_list_of_videos (num_videos = 3 ) as video_list :
63
64
video_clips = VideoClips (video_list , 5 , 5 , num_workers = 2 )
64
- self . assertEqual ( video_clips .num_clips (), 1 + 2 + 3 )
65
+ assert video_clips .num_clips () == 1 + 2 + 3
65
66
for i , (v_idx , c_idx ) in enumerate ([(0 , 0 ), (1 , 0 ), (1 , 1 ), (2 , 0 ), (2 , 1 ), (2 , 2 )]):
66
67
video_idx , clip_idx = video_clips .get_clip_location (i )
67
- self . assertEqual ( video_idx , v_idx )
68
- self . assertEqual ( clip_idx , c_idx )
68
+ assert video_idx == v_idx
69
+ assert clip_idx == c_idx
69
70
70
71
video_clips = VideoClips (video_list , 6 , 6 )
71
- self . assertEqual ( video_clips .num_clips (), 0 + 1 + 2 )
72
+ assert video_clips .num_clips () == 0 + 1 + 2
72
73
for i , (v_idx , c_idx ) in enumerate ([(1 , 0 ), (2 , 0 ), (2 , 1 )]):
73
74
video_idx , clip_idx = video_clips .get_clip_location (i )
74
- self . assertEqual ( video_idx , v_idx )
75
- self . assertEqual ( clip_idx , c_idx )
75
+ assert video_idx == v_idx
76
+ assert clip_idx == c_idx
76
77
77
78
video_clips = VideoClips (video_list , 6 , 1 )
78
- self . assertEqual ( video_clips .num_clips (), 0 + (10 - 6 + 1 ) + (15 - 6 + 1 ) )
79
+ assert video_clips .num_clips () == 0 + (10 - 6 + 1 ) + (15 - 6 + 1 )
79
80
for i , v_idx , c_idx in [(0 , 1 , 0 ), (4 , 1 , 4 ), (5 , 2 , 0 ), (6 , 2 , 1 )]:
80
81
video_idx , clip_idx = video_clips .get_clip_location (i )
81
- self . assertEqual ( video_idx , v_idx )
82
- self . assertEqual ( clip_idx , c_idx )
82
+ assert video_idx == v_idx
83
+ assert clip_idx == c_idx
83
84
84
85
@unittest .skipIf (not io .video ._av_available (), "this test requires av" )
85
86
def test_video_clips_custom_fps (self ):
@@ -89,8 +90,8 @@ def test_video_clips_custom_fps(self):
89
90
video_clips = VideoClips (video_list , num_frames , num_frames , fps , num_workers = 2 )
90
91
for i in range (video_clips .num_clips ()):
91
92
video , audio , info , video_idx = video_clips .get_clip (i )
92
- self . assertEqual ( video .shape [0 ], num_frames )
93
- self . assertEqual ( info ["video_fps" ], fps )
93
+ assert video .shape [0 ] == num_frames
94
+ assert info ["video_fps" ] == fps
94
95
# TODO add tests checking that the content is right
95
96
96
97
def test_compute_clips_for_video (self ):
@@ -103,9 +104,9 @@ def test_compute_clips_for_video(self):
103
104
clips , idxs = VideoClips .compute_clips_for_video (video_pts , num_frames , num_frames ,
104
105
orig_fps , new_fps )
105
106
resampled_idxs = VideoClips ._resample_video_idx (int (duration * new_fps ), orig_fps , new_fps )
106
- self . assertEqual ( len (clips ), 1 )
107
- self . assertTrue (clips . equal ( idxs ) )
108
- self . assertTrue (idxs [0 ]. equal ( resampled_idxs ) )
107
+ assert len (clips ) == 1
108
+ assert_equal (clips , idxs )
109
+ assert_equal (idxs [0 ], resampled_idxs )
109
110
110
111
# case 2: all frames appear only once
111
112
num_frames = 4
@@ -115,9 +116,9 @@ def test_compute_clips_for_video(self):
115
116
clips , idxs = VideoClips .compute_clips_for_video (video_pts , num_frames , num_frames ,
116
117
orig_fps , new_fps )
117
118
resampled_idxs = VideoClips ._resample_video_idx (int (duration * new_fps ), orig_fps , new_fps )
118
- self . assertEqual ( len (clips ), 3 )
119
- self . assertTrue (clips . equal ( idxs ) )
120
- self . assertTrue (idxs .flatten (). equal ( resampled_idxs ) )
119
+ assert len (clips ) == 3
120
+ assert_equal (clips , idxs )
121
+ assert_equal (idxs .flatten (), resampled_idxs )
121
122
122
123
# case 3: frames aren't enough for a clip
123
124
num_frames = 32
@@ -126,8 +127,8 @@ def test_compute_clips_for_video(self):
126
127
with self .assertWarns (UserWarning ):
127
128
clips , idxs = VideoClips .compute_clips_for_video (video_pts , num_frames , num_frames ,
128
129
orig_fps , new_fps )
129
- self . assertEqual ( len (clips ), 0 )
130
- self . assertEqual ( len (idxs ), 0 )
130
+ assert len (clips ) == 0
131
+ assert len (idxs ) == 0
131
132
132
133
133
134
if __name__ == '__main__' :
0 commit comments