Skip to content

Commit d2dc1b9

Browse files
Yosua Michael Maranatha authored and facebook-github-bot committed
[fbsync] add Video feature and kernels (#6667)
Summary:
* add video feature
* add video kernels
* add video testing utils
* add one kernel info
* fix kernel names in Video feature
* use only uint8 for video testing
* require at least 4 dims for Video feature
* add TODO for image_size -> spatial_size
* image -> video in feature constructor
* introduce new combined images and video type
* add video to transform utils
* fix transforms test
* fix auto augment
* cleanup
* address review comments
* add remaining video kernel infos
* add batch dimension squashing to some kernels
* fix tests and kernel infos
* add xfails for arbitrary batch sizes on some kernels
* fix test setup
* fix equalize_image_tensor for multi batch dims
* fix adjust_sharpness_image_tensor for multi batch dims
* address review comments

Reviewed By: NicolasHug

Differential Revision: D40427483

fbshipit-source-id: 748602811638a2b9c56134f14ea107714de86040
1 parent df7db25 commit d2dc1b9

20 files changed

+1171
-257
lines changed

test/prototype_common_utils.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
"make_segmentation_masks",
4646
"make_mask_loaders",
4747
"make_masks",
48+
"make_video",
49+
"make_videos",
4850
]
4951

5052

@@ -210,17 +212,19 @@ def _parse_image_size(size, *, name="size"):
210212

211213
def from_loader(loader_fn):
    """Turn a loader factory into a function that returns the loaded object.

    The returned wrapper accepts the same arguments as ``loader_fn`` plus an
    optional ``device`` keyword argument (default ``"cpu"``). ``device`` is
    removed from the keyword arguments before ``loader_fn`` is called and is
    passed to ``loader.load`` instead.
    """

    def wrapper(*args, **kwargs):
        # Pop rather than get: the loader factory must never see `device`,
        # and the value has to be captured before kwargs are forwarded.
        device = kwargs.pop("device", "cpu")
        loader = loader_fn(*args, **kwargs)
        return loader.load(device)

    return wrapper
217220

218221

219222
def from_loaders(loaders_fn):
    """Turn a factory of several loaders into a generator of loaded objects.

    The returned wrapper is a generator function. It accepts the same
    arguments as ``loaders_fn`` plus an optional ``device`` keyword argument
    (default ``"cpu"``). ``device`` is removed from the keyword arguments
    before ``loaders_fn`` is called, and every loader it produces is loaded
    onto that device and yielded in order.
    """

    def wrapper(*args, **kwargs):
        # Pop rather than get: the loaders factory must never see `device`,
        # and the same value is reused for every loader.
        device = kwargs.pop("device", "cpu")
        loaders = loaders_fn(*args, **kwargs)
        for loader in loaders:
            yield loader.load(device)

    return wrapper
226230

@@ -246,6 +250,21 @@ def __post_init__(self):
246250
self.num_channels = self.shape[-3]
247251

248252

253+
NUM_CHANNELS_MAP = {
254+
features.ColorSpace.GRAY: 1,
255+
features.ColorSpace.GRAY_ALPHA: 2,
256+
features.ColorSpace.RGB: 3,
257+
features.ColorSpace.RGB_ALPHA: 4,
258+
}
259+
260+
261+
def get_num_channels(color_space):
    """Return the channel count registered for ``color_space`` in ``NUM_CHANNELS_MAP``.

    Raises:
        pytest.UsageError: If the lookup yields no (or a falsy) channel count.
    """
    channels = NUM_CHANNELS_MAP.get(color_space)
    if channels:
        return channels
    raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}")
266+
267+
249268
def make_image_loader(
250269
size="random",
251270
*,
@@ -255,16 +274,7 @@ def make_image_loader(
255274
constant_alpha=True,
256275
):
257276
size = _parse_image_size(size)
258-
259-
try:
260-
num_channels = {
261-
features.ColorSpace.GRAY: 1,
262-
features.ColorSpace.GRAY_ALPHA: 2,
263-
features.ColorSpace.RGB: 3,
264-
features.ColorSpace.RGB_ALPHA: 4,
265-
}[color_space]
266-
except KeyError as error:
267-
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error
277+
num_channels = get_num_channels(color_space)
268278

269279
def fn(shape, dtype, device):
270280
max_value = get_max_value(dtype)
@@ -531,3 +541,50 @@ def make_mask_loaders(
531541

532542

533543
# Generator variant of `make_mask_loaders`: each produced loader is loaded
# (on the device given by the optional `device` keyword) and yielded.
make_masks = from_loaders(make_mask_loaders)
544+
545+
546+
class VideoLoader(ImageLoader):
    """Loader for videos; adds nothing on top of ``ImageLoader``."""
548+
549+
550+
def make_video_loader(
    size="random",
    *,
    color_space=features.ColorSpace.RGB,
    num_frames="random",
    extra_dims=(),
    dtype=torch.uint8,
):
    """Build a ``VideoLoader`` describing a video of the requested layout.

    Args:
        size: Spatial size specification, resolved through ``_parse_image_size``.
        color_space: Color space of the video; also determines the channel count.
        num_frames: Number of frames, or ``"random"`` to draw one from
            ``torch.randint(1, 5, ())``.
        extra_dims: Leading dimensions placed in front of the frame dimension.
        dtype: Data type of the video.

    Returns:
        A ``VideoLoader`` with shape ``(*extra_dims, num_frames, channels, *size)``.
    """
    size = _parse_image_size(size)
    if num_frames == "random":
        num_frames = int(torch.randint(1, 5, ()))
    video_shape = (*extra_dims, num_frames, get_num_channels(color_space), *size)

    def fn(shape, dtype, device):
        # Everything in front of the channel dimension (the extra dims and the
        # frame dimension) is passed to `make_image` as its extra dims.
        leading_dims = shape[:-3]
        spatial_size = shape[-2:]
        data = make_image(
            size=spatial_size, color_space=color_space, extra_dims=leading_dims, dtype=dtype, device=device
        )
        return features.Video(data, color_space=color_space)

    return VideoLoader(fn, shape=video_shape, dtype=dtype, color_space=color_space)
568+
569+
570+
# Eager variant of `make_video_loader`: builds the loader and immediately
# loads it on the device given by the optional `device` keyword.
make_video = from_loader(make_video_loader)
571+
572+
573+
def make_video_loaders(
    *,
    sizes=DEFAULT_IMAGE_SIZES,
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.RGB,
    ),
    num_frames=(1, 0, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
):
    """Yield one video loader per parameter combination.

    Each keyword argument is a collection of candidate values. The candidates
    are combined through ``combinations_grid`` and every resulting parameter
    set is forwarded to ``make_video_loader``.
    """
    grid = dict(
        size=sizes,
        color_space=color_spaces,
        num_frames=num_frames,
        extra_dims=extra_dims,
        dtype=dtypes,
    )
    for loader_kwargs in combinations_grid(**grid):
        yield make_video_loader(**loader_kwargs)
588+
589+
590+
# Generator variant of `make_video_loaders`: each produced loader is loaded
# (on the device given by the optional `device` keyword) and yielded.
make_videos = from_loaders(make_video_loaders)

test/prototype_transforms_dispatcher_infos.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,23 @@ def fill_sequence_needs_broadcast(args_kwargs):
127127
)
128128

129129

130+
def xfail_all_tests(*, reason, condition):
    """Create xfail marks for the listed ``TestDispatchers`` tests.

    Args:
        reason: Text passed as the ``reason`` of each ``pytest.mark.xfail``.
        condition: Passed unchanged as the ``condition`` of each ``TestMark``.

    Returns:
        A list with one ``TestMark`` per affected test name.
    """
    affected_tests = (
        "test_scripted_smoke",
        "test_dispatch_simple_tensor",
        "test_dispatch_feature",
    )
    marks = []
    for test_name in affected_tests:
        test_id = ("TestDispatchers", test_name)
        marks.append(TestMark(test_id, pytest.mark.xfail(reason=reason), condition=condition))
    return marks
139+
140+
141+
xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
    # Matches when the first positional argument has more than four dimensions
    # or when any dimension in front of the last three is zero.
    condition=lambda args_kwargs: (
        len(args_kwargs.args[0].shape) > 4 or any(not dim for dim in args_kwargs.args[0].shape[:-3])
    ),
)
145+
146+
130147
DISPATCHER_INFOS = [
131148
DispatcherInfo(
132149
F.horizontal_flip,
@@ -243,6 +260,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
243260
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
244261
test_marks=[
245262
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
263+
*xfails_degenerate_or_multi_batch_dims,
246264
],
247265
),
248266
DispatcherInfo(
@@ -253,6 +271,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
253271
features.Mask: F.elastic_mask,
254272
},
255273
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
274+
test_marks=xfails_degenerate_or_multi_batch_dims,
256275
),
257276
DispatcherInfo(
258277
F.center_crop,
@@ -275,6 +294,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
275294
test_marks=[
276295
xfail_jit_python_scalar_arg("kernel_size"),
277296
xfail_jit_python_scalar_arg("sigma"),
297+
*xfails_degenerate_or_multi_batch_dims,
278298
],
279299
),
280300
DispatcherInfo(

0 commit comments

Comments
 (0)