Skip to content

Commit 12adbe8

Browse files
committed
Updated conditions for _BaseMixupCutmix
1 parent 3ce23ef commit 12adbe8

File tree

2 files changed

+3
-17
lines changed

2 files changed

+3
-17
lines changed

test/test_prototype_transforms.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
make_images,
88
make_bounding_boxes,
99
make_one_hot_labels,
10-
make_segmentation_masks,
11-
make_label,
1210
)
1311
from torchvision.prototype import transforms, features
1412
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
@@ -108,20 +106,6 @@ def test_common(self, transform, input):
108106
def test_mixup_cutmix(self, transform, input):
109107
transform(input)
110108

111-
@pytest.mark.parametrize("transform", [transforms.RandomMixup(alpha=1.0), transforms.RandomCutmix(alpha=1.0)])
112-
def test_mixup_cutmix_assertions(self, transform):
113-
for bbox in make_bounding_boxes():
114-
with pytest.raises(TypeError, match="does not support"):
115-
transform(bbox)
116-
break
117-
for mask in make_segmentation_masks():
118-
with pytest.raises(TypeError, match="does not support"):
119-
transform(mask)
120-
break
121-
label = make_label()
122-
with pytest.raises(TypeError, match="does not support"):
123-
transform(label)
124-
125109
@parametrize(
126110
[
127111
(

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchvision.prototype.transforms import Transform, functional as F
1010

1111
from ._transform import _RandomApplyTransform
12-
from ._utils import query_image, get_image_dimensions, has_any
12+
from ._utils import query_image, get_image_dimensions, has_any, has_all
1313

1414

1515
class RandomErasing(_RandomApplyTransform):
@@ -106,6 +106,8 @@ def __init__(self, *, alpha: float) -> None:
106106

107107
def forward(self, *inpts: Any) -> Any:
108108
sample = inpts if len(inpts) > 1 else inpts[0]
109+
if not has_all(sample, features.Image, features.OneHotLabel):
110+
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
109111
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
110112
raise TypeError(
111113
f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."

0 commit comments

Comments
 (0)