
Commit 46f2308

stephenyan1231 authored and fmassa committed
move sampler into TV core. Update UniformClipSampler (#1408)
* move sampler into TV core. Update UniformClipSampler
* Fix reference training script
* Skip test if pyav not available
* change interpolation from round() to floor(), as round(0.5) behaves differently between py2 and py3
1 parent e2c8e1a commit 46f2308
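
A note on the last bullet (an illustration, not part of the commit): Python 2's `round()` rounds halves away from zero, while Python 3 rounds halves to even, so `round(0.5)` returns `1.0` on py2 but `0` on py3. `floor()` behaves identically on both, which is why the sampler floors the `linspace` output instead:

```python
import torch

# torch.linspace(0, 9, steps=3) -> [0.0, 4.5, 9.0]. Flooring the midpoint
# always gives 4; round(4.5) would give 5.0 on py2 but 4 on py3.
sampled = torch.linspace(0, 9, steps=3).floor().to(torch.int64)
print(sampled.tolist())  # [0, 4, 9]
```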

File tree

5 files changed, +113 -39 lines changed

references/video_classification/train.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -11,9 +11,10 @@
 import torchvision
 import torchvision.datasets.video_utils
 from torchvision import transforms
+from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
 import utils
-from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
+
 from scheduler import WarmupMultiStepLR
 import transforms as T
 
```
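
With the samplers in torchvision core, the reference script only needs the single import above. A minimal sketch of how they are typically attached to DataLoaders (this assumes Kinetics400-style datasets that expose a `video_clips` attribute; the function name and batch size are illustrative, not from this commit):

```python
import torch
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler

def make_loaders(dataset, dataset_test, clips_per_video=5, batch_size=16):
    # Training: a few random clips per video; evaluation: evenly spaced clips.
    train_sampler = RandomClipSampler(dataset.video_clips, clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, clips_per_video)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_sampler)
    loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size, sampler=test_sampler)
    return loader, loader_test
```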

test/test_datasets_samplers.py

Lines changed: 88 additions & 0 deletions
New file:

```python
import contextlib
import sys
import os
import torch
import unittest

from torchvision import io
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

from common_utils import get_tmp_dir


@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
    with get_tmp_dir() as tmp_dir:
        names = []
        for i in range(num_videos):
            if sizes is None:
                size = 5 * (i + 1)
            else:
                size = sizes[i]
            if fps is None:
                f = 5
            else:
                f = fps[i]
            data = torch.randint(0, 255, (size, 300, 400, 3), dtype=torch.uint8)
            name = os.path.join(tmp_dir, "{}.mp4".format(i))
            names.append(name)
            io.write_video(name, data, fps=f)

        yield names


@unittest.skipIf(not io.video._av_available(), "this test requires av")
class Tester(unittest.TestCase):
    def test_random_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))

    def test_random_clip_sampler_unequal(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 2 + 3 + 3)
            indices = list(iter(sampler))
            self.assertIn(0, indices)
            self.assertIn(1, indices)
            # remove elements of the first video, to simplify testing
            indices.remove(0)
            indices.remove(1)
            indices = torch.tensor(indices) - 2
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
            self.assertTrue(count.equal(torch.tensor([3, 3])))

    def test_uniform_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14])))

    def test_uniform_clip_sampler_insufficient_clips(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))


if __name__ == '__main__':
    unittest.main()
```
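
The expected tensors in the two uniform-sampler tests can be checked by hand: a 25-frame video with clip length 5 and step 5 yields 5 clips, so the per-video offsets `s` are 0, 5, 10, while a 10-frame video yields only 2 clips. A standalone sketch mirroring the sampler's arithmetic (`uniform_indices` is a hypothetical helper, not part of the test file):

```python
import torch

def uniform_indices(s, length, num_clips=3):
    # Same arithmetic as UniformClipSampler.__iter__ runs for one video.
    return torch.linspace(s, s + length - 1, steps=num_clips).floor().to(torch.int64).tolist()

# Three 25-frame videos -> 5 clips each at offsets 0, 5, 10:
print(uniform_indices(0, 5), uniform_indices(5, 5), uniform_indices(10, 5))
# [0, 2, 4] [5, 7, 9] [10, 12, 14]  -> concatenated: the expected tensor

# First video only 10 frames -> 2 clips; indices repeat to reach 3:
print(uniform_indices(0, 2))  # [0, 0, 1]
```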

test/test_datasets_video_utils.py

Lines changed: 0 additions & 30 deletions
```diff
@@ -81,36 +81,6 @@ def test_video_clips(self):
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 3 * 3)
-            indices = torch.tensor(list(iter(sampler)))
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
-            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
-
-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler_unequal(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 2 + 3 + 3)
-            indices = list(iter(sampler))
-            self.assertIn(0, indices)
-            self.assertIn(1, indices)
-            # remove elements of the first video, to simplify testing
-            indices.remove(0)
-            indices.remove(1)
-            indices = torch.tensor(indices) - 2
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
-            self.assertTrue(count.equal(torch.tensor([3, 3])))
-
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     def test_video_clips_custom_fps(self):
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
```
torchvision/datasets/samplers/__init__.py

Lines changed: 3 additions & 0 deletions

New file:

```python
from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler

__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler')
```

references/video_classification/sampler.py renamed to torchvision/datasets/samplers/clip_sampler.py

Lines changed: 20 additions & 8 deletions
```diff
@@ -60,33 +60,45 @@ def set_epoch(self, epoch):
 
 class UniformClipSampler(torch.utils.data.Sampler):
     """
-    Samples at most `max_video_clips_per_video` clips for each video, equally spaced
+    Sample `num_video_clips_per_video` clips for each video, equally spaced.
+    When number of unique clips in the video is fewer than num_video_clips_per_video,
+    repeat the clips until `num_video_clips_per_video` clips are collected
+
     Arguments:
         video_clips (VideoClips): video clips to sample from
-        max_clips_per_video (int): maximum number of clips to be sampled per video
+        num_clips_per_video (int): number of clips to be sampled per video
     """
-    def __init__(self, video_clips, max_clips_per_video):
+    def __init__(self, video_clips, num_clips_per_video):
         if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
-        self.max_clips_per_video = max_clips_per_video
+        self.num_clips_per_video = num_clips_per_video
 
     def __iter__(self):
         idxs = []
         s = 0
-        # select at most max_clips_per_video for each video, uniformly spaced
+        # select num_clips_per_video for each video, uniformly spaced
        for c in self.video_clips.clips:
             length = len(c)
-            step = max(length // self.max_clips_per_video, 1)
-            sampled = torch.arange(length)[::step] + s
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+
+            sampled = (
+                torch.linspace(s, s + length - 1, steps=self.num_clips_per_video)
+                .floor()
+                .to(torch.int64)
+            )
             s += length
             idxs.append(sampled)
         idxs = torch.cat(idxs).tolist()
         return iter(idxs)
 
     def __len__(self):
-        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
+        return sum(
+            self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
+        )
 
 
 class RandomClipSampler(torch.utils.data.Sampler):
```
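
The behavioral change is easiest to see side by side. A standalone sketch re-implementing the two `__iter__` loops above for a single video (illustration only, not code from the commit):

```python
import torch

def old_uniform(length, max_clips):
    # Pre-commit: stride-based selection. A coarse stride can return MORE
    # than max_clips, and short videos return fewer.
    step = max(length // max_clips, 1)
    return torch.arange(length)[::step].tolist()

def new_uniform(length, num_clips):
    # Post-commit: always exactly num_clips, evenly spaced, repeating
    # indices when the video has too few clips.
    return torch.linspace(0, length - 1, steps=num_clips).floor().to(torch.int64).tolist()

print(old_uniform(7, 3))  # [0, 2, 4, 6] -> 4 clips despite max_clips=3
print(new_uniform(7, 3))  # [0, 3, 6]    -> exactly 3
print(old_uniform(2, 3))  # [0, 1]       -> only 2 clips
print(new_uniform(2, 3))  # [0, 0, 1]    -> repeated to reach 3
```

This fixed-size output is also why `__len__` now returns `num_clips_per_video` for every non-empty video rather than `min(len(c), max_clips_per_video)`.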
