Skip to content

Commit 05e061f

Browse files
NicolasHugpmeier
andauthored
Use torch.testing.assert_close in datasets_video_utils.py (#3875)
Co-authored-by: Philip Meier <[email protected]>
1 parent 05a3941 commit 05e061f

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

test/test_datasets_video_utils.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torchvision.datasets.video_utils import VideoClips, unfold
88

99
from common_utils import get_tmp_dir
10+
from _assert_utils import assert_equal
1011

1112

1213
@contextlib.contextmanager
@@ -40,46 +41,46 @@ def test_unfold(self):
4041
[0, 1, 2],
4142
[3, 4, 5],
4243
])
43-
self.assertTrue(r.equal(expected))
44+
assert_equal(r, expected, check_stride=False)
4445

4546
r = unfold(a, 3, 2, 1)
4647
expected = torch.tensor([
4748
[0, 1, 2],
4849
[2, 3, 4],
4950
[4, 5, 6]
5051
])
51-
self.assertTrue(r.equal(expected))
52+
assert_equal(r, expected, check_stride=False)
5253

5354
r = unfold(a, 3, 2, 2)
5455
expected = torch.tensor([
5556
[0, 2, 4],
5657
[2, 4, 6],
5758
])
58-
self.assertTrue(r.equal(expected))
59+
assert_equal(r, expected, check_stride=False)
5960

6061
@unittest.skipIf(not io.video._av_available(), "this test requires av")
6162
def test_video_clips(self):
6263
with get_list_of_videos(num_videos=3) as video_list:
6364
video_clips = VideoClips(video_list, 5, 5, num_workers=2)
64-
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
65+
assert video_clips.num_clips() == 1 + 2 + 3
6566
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
6667
video_idx, clip_idx = video_clips.get_clip_location(i)
67-
self.assertEqual(video_idx, v_idx)
68-
self.assertEqual(clip_idx, c_idx)
68+
assert video_idx == v_idx
69+
assert clip_idx == c_idx
6970

7071
video_clips = VideoClips(video_list, 6, 6)
71-
self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
72+
assert video_clips.num_clips() == 0 + 1 + 2
7273
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
7374
video_idx, clip_idx = video_clips.get_clip_location(i)
74-
self.assertEqual(video_idx, v_idx)
75-
self.assertEqual(clip_idx, c_idx)
75+
assert video_idx == v_idx
76+
assert clip_idx == c_idx
7677

7778
video_clips = VideoClips(video_list, 6, 1)
78-
self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
79+
assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1)
7980
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
8081
video_idx, clip_idx = video_clips.get_clip_location(i)
81-
self.assertEqual(video_idx, v_idx)
82-
self.assertEqual(clip_idx, c_idx)
82+
assert video_idx == v_idx
83+
assert clip_idx == c_idx
8384

8485
@unittest.skipIf(not io.video._av_available(), "this test requires av")
8586
def test_video_clips_custom_fps(self):
@@ -89,8 +90,8 @@ def test_video_clips_custom_fps(self):
8990
video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2)
9091
for i in range(video_clips.num_clips()):
9192
video, audio, info, video_idx = video_clips.get_clip(i)
92-
self.assertEqual(video.shape[0], num_frames)
93-
self.assertEqual(info["video_fps"], fps)
93+
assert video.shape[0] == num_frames
94+
assert info["video_fps"] == fps
9495
# TODO add tests checking that the content is right
9596

9697
def test_compute_clips_for_video(self):
@@ -103,9 +104,9 @@ def test_compute_clips_for_video(self):
103104
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
104105
orig_fps, new_fps)
105106
resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
106-
self.assertEqual(len(clips), 1)
107-
self.assertTrue(clips.equal(idxs))
108-
self.assertTrue(idxs[0].equal(resampled_idxs))
107+
assert len(clips) == 1
108+
assert_equal(clips, idxs)
109+
assert_equal(idxs[0], resampled_idxs)
109110

110111
# case 2: all frames appear only once
111112
num_frames = 4
@@ -115,9 +116,9 @@ def test_compute_clips_for_video(self):
115116
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
116117
orig_fps, new_fps)
117118
resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
118-
self.assertEqual(len(clips), 3)
119-
self.assertTrue(clips.equal(idxs))
120-
self.assertTrue(idxs.flatten().equal(resampled_idxs))
119+
assert len(clips) == 3
120+
assert_equal(clips, idxs)
121+
assert_equal(idxs.flatten(), resampled_idxs)
121122

122123
# case 3: frames aren't enough for a clip
123124
num_frames = 32
@@ -126,8 +127,8 @@ def test_compute_clips_for_video(self):
126127
with self.assertWarns(UserWarning):
127128
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
128129
orig_fps, new_fps)
129-
self.assertEqual(len(clips), 0)
130-
self.assertEqual(len(idxs), 0)
130+
assert len(clips) == 0
131+
assert len(idxs) == 0
131132

132133

133134
if __name__ == '__main__':

0 commit comments

Comments
 (0)