Skip to content

Commit 997348d

Browse files
fix: add appropriate error message when validating padding argument. (#8959)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 0f30dff commit 997348d

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

test/test_transforms_v2.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3022,12 +3022,18 @@ def test_errors(self):
30223022
with pytest.raises(ValueError, match="Please provide only two dimensions"):
30233023
transforms.RandomCrop([10, 12, 14])
30243024

3025-
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
3025+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
30263026
transforms.RandomCrop([10, 12], padding="abc")
30273027

30283028
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
30293029
transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])
30303030

3031+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3032+
transforms.RandomCrop([10, 12], padding=0.5)
3033+
3034+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3035+
transforms.RandomCrop([10, 12], padding=[0.5, 0.5])
3036+
30313037
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
30323038
transforms.RandomCrop([10, 12], padding=1, fill="abc")
30333039

@@ -3892,12 +3898,18 @@ def test_transform(self, make_input):
38923898
check_transform(transforms.Pad(padding=[1]), make_input())
38933899

38943900
def test_transform_errors(self):
3895-
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
3901+
with pytest.raises(ValueError, match="Padding must be"):
38963902
transforms.Pad("abc")
38973903

3898-
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3904+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
38993905
transforms.Pad([-0.7, 0, 0.7])
39003906

3907+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
3908+
transforms.Pad(0.5)
3909+
3910+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
3911+
transforms.Pad(padding=[0.5, 0.5])
3912+
39013913
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
39023914
transforms.Pad(12, fill="abc")
39033915

torchvision/transforms/v2/_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ def _get_fill(fill_dict, inpt_type):
8181

8282

8383
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
84-
if not isinstance(padding, (numbers.Number, tuple, list)):
85-
raise TypeError("Got inappropriate padding arg")
8684

87-
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
88-
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
85+
err_msg = f"Padding must be an int or a 1, 2, or 4 element of tuple or list, got {padding}."
86+
if isinstance(padding, (tuple, list)):
87+
if len(padding) not in [1, 2, 4] or not all(isinstance(p, int) for p in padding):
88+
raise ValueError(err_msg)
89+
elif not isinstance(padding, int):
90+
raise ValueError(err_msg)
8991

9092

9193
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)

0 commit comments

Comments
 (0)