Skip to content

Commit f1e2bfa

Browse files
committed
fix tests and kernel infos
1 parent 0d2ad96 commit f1e2bfa

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

test/prototype_common_utils.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,21 @@ def __post_init__(self):
250250
self.num_channels = self.shape[-3]
251251

252252

253+
NUM_CHANNELS_MAP = {
254+
features.ColorSpace.GRAY: 1,
255+
features.ColorSpace.GRAY_ALPHA: 2,
256+
features.ColorSpace.RGB: 3,
257+
features.ColorSpace.RGB_ALPHA: 4,
258+
}
259+
260+
261+
def get_num_channels(color_space):
262+
num_channels = NUM_CHANNELS_MAP.get(color_space)
263+
if not num_channels:
264+
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}")
265+
return num_channels
266+
267+
253268
def make_image_loader(
254269
size="random",
255270
*,
@@ -259,16 +274,7 @@ def make_image_loader(
259274
constant_alpha=True,
260275
):
261276
size = _parse_image_size(size)
262-
263-
try:
264-
num_channels = {
265-
features.ColorSpace.GRAY: 1,
266-
features.ColorSpace.GRAY_ALPHA: 2,
267-
features.ColorSpace.RGB: 3,
268-
features.ColorSpace.RGB_ALPHA: 4,
269-
}[color_space]
270-
except KeyError as error:
271-
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error
277+
num_channels = get_num_channels(color_space)
272278

273279
def fn(shape, dtype, device):
274280
max_value = get_max_value(dtype)
@@ -550,13 +556,15 @@ def make_video_loader(
550556
dtype=torch.uint8,
551557
):
552558
size = _parse_image_size(size)
553-
num_frames = int(torch.randint(1, 4, ())) if num_frames == "random" else num_frames
559+
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
554560

555561
def fn(shape, dtype, device):
556562
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-2], dtype=dtype, device=device)
557563
return features.Video(video, color_space=color_space)
558564

559-
return VideoLoader(fn, shape=(*extra_dims, num_frames, *size), dtype=dtype, color_space=color_space)
565+
return VideoLoader(
566+
fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
567+
)
560568

561569

562570
make_video = from_loader(make_video_loader)

test/prototype_transforms_kernel_infos.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def sample_inputs_horizontal_flip_mask():
171171

172172

173173
def sample_inputs_horizontal_flip_video():
174-
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"], dtypes=[torch.float32]):
174+
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
175175
yield ArgsKwargs(video_loader)
176176

177177

@@ -298,8 +298,8 @@ def reference_inputs_resize_mask():
298298

299299

300300
def sample_inputs_resize_video():
301-
for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
302-
yield ArgsKwargs(mask_loader, size=[min(mask_loader.shape[-2:]) + 1])
301+
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
302+
yield ArgsKwargs(video_loader, size=[min(video_loader.shape[-2:]) + 1])
303303

304304

305305
KERNEL_INFOS.extend(
@@ -522,8 +522,8 @@ def reference_inputs_resize_mask():
522522

523523

524524
def sample_inputs_affine_video():
525-
for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
526-
yield ArgsKwargs(mask_loader, **_full_affine_params())
525+
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
526+
yield ArgsKwargs(video_loader, **_full_affine_params())
527527

528528

529529
KERNEL_INFOS.extend(
@@ -1364,7 +1364,7 @@ def sample_inputs_gaussian_blur_image_tensor():
13641364

13651365
def sample_inputs_gaussian_blur_video():
13661366
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
1367-
yield ArgsKwargs(video_loader, kernel_size=3)
1367+
yield ArgsKwargs(video_loader, kernel_size=[3, 3])
13681368

13691369

13701370
KERNEL_INFOS.extend(
@@ -1967,7 +1967,9 @@ def sample_inputs_normalize_image_tensor():
19671967

19681968
def sample_inputs_normalize_video():
19691969
mean, std = _NORMALIZE_MEANS_STDS[0]
1970-
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
1970+
for video_loader in make_video_loaders(
1971+
sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
1972+
):
19711973
yield ArgsKwargs(video_loader, mean=mean, std=std)
19721974

19731975

0 commit comments

Comments
 (0)