Skip to content

Commit d8945e6

Browse files
committed
fix test setup
1 parent ad4d424 commit d8945e6

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

test/prototype_common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def make_video_loader(
559559
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
560560

561561
def fn(shape, dtype, device):
562-
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-2], dtype=dtype, device=device)
562+
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
563563
return features.Video(video, color_space=color_space)
564564

565565
return VideoLoader(

test/test_prototype_transforms.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,8 @@ def test_mixup_cutmix(self, transform, input):
157157
features.ColorSpace.RGB,
158158
],
159159
dtypes=[torch.uint8],
160-
**(
161-
dict(num_frames=[1, "random"], extra_dims=[()])
162-
if fn is make_videos
163-
else dict(extra_dims=[(4,)])
164-
),
160+
extra_dims=[(), (4,)],
161+
**(dict(num_frames=["random"]) if fn is make_videos else dict()),
165162
)
166163
for fn in [
167164
make_images,

0 commit comments

Comments
 (0)