
Commit 3db55c1

Merge branch 'main' into szhi-s3d

2 parents: 15cf85d + 330b6c9

12 files changed: +380 / -61 lines

test/test_prototype_transforms.py

Lines changed: 126 additions & 7 deletions

@@ -6,7 +6,7 @@

 import pytest
 import torch
-from common_utils import assert_equal
+from common_utils import assert_equal, cpu_and_gpu
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
@@ -15,6 +15,7 @@
     make_one_hot_labels,
     make_segmentation_mask,
 )
+from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

@@ -793,7 +794,7 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
         if p > 0.0:
             fn.assert_called_once_with(inpt, **kwargs)
         else:
-            fn.call_count == 0
+            assert fn.call_count == 0


 class TestRandomPerspective:
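
A note on the recurring fix in this file: the old line built a comparison and discarded the result, so the else-branch asserted nothing and could never fail. A minimal standalone illustration (hypothetical mock, not part of this diff):

from unittest import mock

fn = mock.MagicMock()
fn()  # the mock has actually been called once

fn.call_count == 0          # evaluates to False and is silently discarded
assert fn.call_count == 0   # fails loudly, which is what the test intends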
@@ -1014,7 +1015,7 @@ def test__transform(self, p, inpt_type, mocker):
         if p > 0.0:
             fn.assert_called_once_with(erase_image_tensor_inpt, **params)
         else:
-            fn.call_count == 0
+            assert fn.call_count == 0


 class TestTransform:
@@ -1050,7 +1051,7 @@ def test__transform(self, inpt_type, mocker):
         transform = transforms.ToImageTensor()
         transform(inpt)
         if inpt_type in (features.BoundingBox, str, int):
-            fn.call_count == 0
+            assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, copy=transform.copy)

@@ -1067,7 +1068,7 @@ def test__transform(self, inpt_type, mocker):
         transform = transforms.ToImagePIL()
         transform(inpt)
         if inpt_type in (features.BoundingBox, str, int):
-            fn.call_count == 0
+            assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, copy=transform.copy)

@@ -1085,7 +1086,7 @@ def test__transform(self, inpt_type, mocker):
         transform = transforms.ToPILImage()
         transform(inpt)
         if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
-            fn.call_count == 0
+            assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -1103,7 +1104,7 @@ def test__transform(self, inpt_type, mocker):
        transform = transforms.ToTensor()
         transform(inpt)
         if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int):
-            fn.call_count == 0
+            assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)

@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
         assert isinstance(output, torch.Tensor)


+class TestRandomIoUCrop:
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
+    def test__get_params(self, device, options, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        bboxes = features.BoundingBox(
+            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
+            format="XYXY",
+            image_size=image.image_size,
+            device=device,
+        )
+        sample = [image, bboxes]
+
+        transform = transforms.RandomIoUCrop(sampler_options=options)
+
+        n_samples = 5
+        for _ in range(n_samples):
+
+            params = transform._get_params(sample)
+
+            if options == [2.0]:
+                assert len(params) == 0
+                return
+
+            assert len(params["is_within_crop_area"]) > 0
+            assert params["is_within_crop_area"].dtype == torch.bool
+
+            orig_h = image.image_size[0]
+            orig_w = image.image_size[1]
+            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
+            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
+
+            left, top = params["left"], params["top"]
+            new_h, new_w = params["height"], params["width"]
+            ious = box_iou(
+                bboxes,
+                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
+            )
+            assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"

+    def test__transform_empty_params(self, mocker):
+        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
+        image = features.Image(torch.rand(1, 3, 4, 4))
+        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
+        label = features.Label(torch.tensor([1]))
+        sample = [image, bboxes, label]
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock(return_value={})
+        output = transform(sample)
+        torch.testing.assert_close(output, sample)
+
+    def test_forward_assertion(self):
+        transform = transforms.RandomIoUCrop()
+        with pytest.raises(
+            TypeError,
+            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
+        ):
+            transform(torch.tensor(0))
+
+    def test__transform(self, mocker):
+        transform = transforms.RandomIoUCrop()
+
+        image = features.Image(torch.rand(3, 32, 24))
+        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
+        label = features.Label(torch.randint(0, 10, size=(6,)))
+        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
+        masks = make_segmentation_mask((32, 24))
+        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
+        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
+        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
+
+        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
+        transform._get_params = mocker.MagicMock(return_value=params)
+        output = transform(sample)
+
+        assert fn.call_count == 4
+
+        expected_calls = [
+            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(
+                ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            ),
+        ]
+
+        fn.assert_has_calls(expected_calls)
+
+        expected_within_targets = sum(is_within_crop_area)
+
+        # check number of bboxes vs number of labels:
+        output_bboxes = output[1]
+        assert isinstance(output_bboxes, features.BoundingBox)
+        assert len(output_bboxes) == expected_within_targets
+
+        # check labels
+        output_label = output[2]
+        assert isinstance(output_label, features.Label)
+        assert len(output_label) == expected_within_targets
+        torch.testing.assert_close(output_label, label[is_within_crop_area])
+
+        output_ohe_label = output[3]
+        assert isinstance(output_ohe_label, features.OneHotLabel)
+        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
+
+        output_masks = output[4]
+        assert isinstance(output_masks, features.SegmentationMask)
+        assert output_masks.shape[:-2] == masks.shape[:-2]
+
+        output_ohe_masks = output[5]
+        assert isinstance(output_ohe_masks, features.SegmentationMask)
+        assert len(output_ohe_masks) == expected_within_targets
+
+
 class TestScaleJitter:
     def test__get_params(self, mocker):
         image_size = (24, 32)
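
For context, a minimal sketch of how RandomIoUCrop is driven in these tests; the shapes and sampler options here are illustrative, and the API is the prototype one exercised above, not a stable interface:

import torch
from torchvision.prototype import features, transforms

image = features.Image(torch.rand(3, 32, 24))
bboxes = features.BoundingBox(
    torch.tensor([[1, 1, 10, 10], [5, 5, 20, 20]]), format="XYXY", image_size=(32, 24)
)
labels = features.Label(torch.tensor([3, 7]))

# _get_params samples a crop whose IoU with the boxes clears one of the
# sampler_options thresholds; boxes flagged outside the crop area
# (is_within_crop_area) and their labels are dropped, as asserted above.
crop = transforms.RandomIoUCrop(sampler_options=[0.3, 0.5, 0.7])
new_image, new_bboxes, new_labels = crop([image, bboxes, labels])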
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+import PIL.Image
+import pytest
+
+import torch
+
+from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
+
+from torchvision.prototype import features
+from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
+from torchvision.prototype.transforms.functional import to_image_pil
+
+
+IMAGE = make_image(color_space=features.ColorSpace.RGB)
+BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
+SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
+
+
+@pytest.mark.parametrize(
+    ("sample", "types", "expected"),
+    [
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
+        ((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
+        ((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+        ((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+        ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+        ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+    ],
+)
+def test_has_any(sample, types, expected):
+    assert has_any(sample, *types) is expected
+
+
+@pytest.mark.parametrize(
+    ("sample", "types", "expected"),
+    [
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
+        ((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
+            True,
+        ),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+    ],
+)
+def test_has_all(sample, types, expected):
+    assert has_all(sample, *types) is expected
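
These parametrizations pin down the contract: a bare type is matched with isinstance, a callable is used as a predicate, has_any needs one match anywhere in the sample, and has_all needs every entry matched by something. A plausible sketch of that contract for the flat-sequence case only (the real helpers in torchvision.prototype.transforms._utils also walk nested samples):

from typing import Any, Callable, Sequence, Type, Union

TypeOrCheck = Union[Type, Callable[[Any], bool]]

def _matches(obj: Any, type_or_check: TypeOrCheck) -> bool:
    # Bare types mean isinstance; callables act as predicates.
    if isinstance(type_or_check, type):
        return isinstance(obj, type_or_check)
    return type_or_check(obj)

def has_any(sample: Sequence[Any], *types_or_checks: TypeOrCheck) -> bool:
    # At least one item matches at least one type/predicate.
    return any(_matches(obj, t) for t in types_or_checks for obj in sample)

def has_all(sample: Sequence[Any], *types_or_checks: TypeOrCheck) -> bool:
    # Every type/predicate is matched by some item.
    return all(any(_matches(obj, t) for obj in sample) for t in types_or_checks)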

torchvision/datasets/food101.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 class Food101(VisionDataset):
     """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

-    The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
+    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
     For each class, 250 manually reviewed test images are provided as well as 750 training images.
     On purpose, the training images were not cleaned, and thus still contain some amount of noise.
     This comes mostly in the form of intense colors and sometimes wrong labels. All images were

torchvision/datasets/ucf101.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ class UCF101(VisionDataset):
     by ``frames_per_clip``, where the step in frames between each clip is given by
     ``step_between_clips``. The dataset itself can be downloaded from the dataset website;
     annotations that ``annotation_path`` should be pointing to can be downloaded from `here
-    <https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip>`.
+    <https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip>`_.

     To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
     and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
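
The trailing underscore matters in reStructuredText: ``\`here <url>\`_`` is a named hyperlink reference, while the same construct without the underscore is treated as plain interpreted text, so before this fix the rendered docs showed a broken link.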

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
     RandomAffine,
     RandomCrop,
     RandomHorizontalFlip,
+    RandomIoUCrop,
     RandomPerspective,
     RandomResizedCrop,
     RandomRotation,
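
Once exported here, the transform is reachable from the prototype namespace root; a quick smoke check (assuming a torchvision build with the prototype area enabled):

from torchvision.prototype import transforms

crop = transforms.RandomIoUCrop()  # now resolvable from the package root
print(type(crop).__name__)  # RandomIoUCrop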

torchvision/prototype/transforms/_augment.py

Lines changed: 8 additions & 6 deletions

@@ -6,10 +6,10 @@
 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.transforms import functional as F, Transform
+from torchvision.prototype.transforms import functional as F

 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image


 class RandomErasing(_RandomApplyTransform):
@@ -97,15 +97,17 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt


-class _BaseMixupCutmix(Transform):
-    def __init__(self, *, alpha: float) -> None:
-        super().__init__()
+class _BaseMixupCutmix(_RandomApplyTransform):
+    def __init__(self, *, alpha: float, p: float = 0.5) -> None:
+        super().__init__(p=p)
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

     def forward(self, *inpts: Any) -> Any:
         sample = inpts if len(inpts) > 1 else inpts[0]
-        if not has_all(sample, features.Image, features.OneHotLabel):
+        if not (
+            has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
+        ):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
             raise TypeError(
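
Two behavioral consequences of this hunk, as far as the diff shows: _BaseMixupCutmix now inherits the p-gated apply logic from _RandomApplyTransform, and the input check accepts any image-like entry (features.Image, PIL image, or plain tensor) with a OneHotLabel, rather than requiring a features.Image. A sketch of what the relaxed check admits, with a hypothetical sample:

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_any, is_simple_tensor

plain_image = torch.rand(3, 8, 8)                    # a simple tensor, not a features.Image
ohe_label = features.OneHotLabel(torch.eye(10)[:1])  # one-hot label feature
sample = [plain_image, ohe_label]

# Old check: has_all(sample, features.Image, features.OneHotLabel) -> False.
# New check: a plain tensor counts as image-like, so the sample is accepted.
ok = has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
print(ok)  # True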
