Skip to content

Commit 1a300d8

Browse files
avijit9datumbox
andauthored
Cleanup functional_tensor.py (#3159) (#3171)
* added the helper method for dimension checks * unit tests for dimensio check function in functional_tensor * code formatting and typing * moved torch image check after tensor check * unit testcases for test_assert_image_tensor added and refactored * separate unit testcase file deleted * assert_image_tensor added to newly created 6 methods * test cases added for new 6 mthohds * removed wrongly pasted posterize method and added solarize method for testing Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 90645cc commit 1a300d8

File tree

2 files changed

+71
-44
lines changed

2 files changed

+71
-44
lines changed

test/test_functional_tensor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from common_utils import TransformsTester
1515

16+
from typing import Dict, List, Tuple
17+
1618

1719
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
1820

@@ -34,6 +36,28 @@ def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
3436
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
3537
self.assertTrue(transformed_batch.allclose(s_transformed_batch))
3638

39+
def test_assert_image_tensor(self):
40+
shape = (100,)
41+
tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
42+
43+
list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
44+
(F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
45+
(F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
46+
(F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
47+
(F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
48+
(F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
49+
(F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
50+
(F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
51+
(F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
52+
(F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
53+
(F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]
54+
55+
for func, args in list_of_methods:
56+
with self.assertRaises(Exception) as context:
57+
func(*args)
58+
59+
self.assertTrue('Tensor is not a torch image.' in str(context.exception))
60+
3761
def test_vflip(self):
3862
script_vflip = torch.jit.script(F.vflip)
3963

torchvision/transforms/functional_tensor.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
1111
return x.ndim >= 2
1212

1313

14+
def _assert_image_tensor(img):
15+
if not _is_tensor_a_torch_image(img):
16+
raise TypeError("Tensor is not a torch image.")
17+
18+
1419
def _get_image_size(img: Tensor) -> List[int]:
1520
"""Returns (w, h) of tensor image"""
16-
if _is_tensor_a_torch_image(img):
17-
return [img.shape[-1], img.shape[-2]]
18-
raise TypeError("Unexpected input type")
21+
_assert_image_tensor(img)
22+
return [img.shape[-1], img.shape[-2]]
1923

2024

2125
def _get_image_num_channels(img: Tensor) -> int:
@@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
143147
Returns:
144148
Tensor: Vertically flipped image Tensor.
145149
"""
146-
if not _is_tensor_a_torch_image(img):
147-
raise TypeError('tensor is not a torch image.')
150+
_assert_image_tensor(img)
148151

149152
return img.flip(-2)
150153

@@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
163166
Returns:
164167
Tensor: Horizontally flipped image Tensor.
165168
"""
166-
if not _is_tensor_a_torch_image(img):
167-
raise TypeError('tensor is not a torch image.')
169+
_assert_image_tensor(img)
168170

169171
return img.flip(-1)
170172

@@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
187189
Returns:
188190
Tensor: Cropped image.
189191
"""
190-
if not _is_tensor_a_torch_image(img):
191-
raise TypeError("tensor is not a torch image.")
192+
_assert_image_tensor(img)
192193

193194
return img[..., top:top + height, left:left + width]
194195

@@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
254255
if brightness_factor < 0:
255256
raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))
256257

257-
if not _is_tensor_a_torch_image(img):
258-
raise TypeError('tensor is not a torch image.')
258+
_assert_image_tensor(img)
259259

260260
_assert_channels(img, [1, 3])
261261

@@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
282282
if contrast_factor < 0:
283283
raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))
284284

285-
if not _is_tensor_a_torch_image(img):
286-
raise TypeError('tensor is not a torch image.')
285+
_assert_image_tensor(img)
287286

288287
_assert_channels(img, [3])
289288

@@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
326325
if not (-0.5 <= hue_factor <= 0.5):
327326
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
328327

329-
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
328+
if not (isinstance(img, torch.Tensor)):
330329
raise TypeError('Input img should be Tensor image')
331330

331+
_assert_image_tensor(img)
332+
332333
_assert_channels(img, [3])
333334

334335
orig_dtype = img.dtype
@@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
367368
if saturation_factor < 0:
368369
raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))
369370

370-
if not _is_tensor_a_torch_image(img):
371-
raise TypeError('tensor is not a torch image.')
371+
_assert_image_tensor(img)
372372

373373
_assert_channels(img, [3])
374374

@@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
447447
"Please, use ``F.center_crop`` instead."
448448
)
449449

450-
if not _is_tensor_a_torch_image(img):
451-
raise TypeError('tensor is not a torch image.')
450+
_assert_image_tensor(img)
452451

453452
_, image_width, image_height = img.size()
454453
crop_height, crop_width = output_size
@@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
497496
"Please, use ``F.five_crop`` instead."
498497
)
499498

500-
if not _is_tensor_a_torch_image(img):
501-
raise TypeError('tensor is not a torch image.')
499+
_assert_image_tensor(img)
502500

503501
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
504502

@@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
553551
"Please, use ``F.ten_crop`` instead."
554552
)
555553

556-
if not _is_tensor_a_torch_image(img):
557-
raise TypeError('tensor is not a torch image.')
554+
_assert_image_tensor(img)
558555

559556
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
560557
first_five = five_crop(img, size)
@@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
703700
Returns:
704701
Tensor: Padded image.
705702
"""
706-
if not _is_tensor_a_torch_image(img):
707-
raise TypeError("tensor is not a torch image.")
703+
_assert_image_tensor(img)
708704

709705
if not isinstance(padding, (int, tuple, list)):
710706
raise TypeError("Got inappropriate padding arg")
@@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
796792
Returns:
797793
Tensor: Resized image.
798794
"""
799-
if not _is_tensor_a_torch_image(img):
800-
raise TypeError("tensor is not a torch image.")
795+
_assert_image_tensor(img)
801796

802797
if not isinstance(size, (int, tuple, list)):
803798
raise TypeError("Got inappropriate size arg")
@@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
855850
supported_interpolation_modes: List[str],
856851
coeffs: Optional[List[float]] = None,
857852
):
858-
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
859-
raise TypeError("Input img should be Tensor Image")
853+
854+
if not (isinstance(img, torch.Tensor)):
855+
raise TypeError("Input img should be Tensor")
856+
857+
_assert_image_tensor(img)
860858

861859
if matrix is not None and not isinstance(matrix, list):
862860
raise TypeError("Argument matrix should be a list")
@@ -1112,8 +1110,11 @@ def perspective(
11121110
Returns:
11131111
Tensor: transformed image.
11141112
"""
1115-
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
1116-
raise TypeError('Input img should be Tensor Image')
1113+
1114+
if not (isinstance(img, torch.Tensor)):
1115+
raise TypeError('Input img should be Tensor.')
1116+
1117+
_assert_image_tensor(img)
11171118

11181119
_assert_grid_transform_inputs(
11191120
img,
@@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11651166
Returns:
11661167
Tensor: An image that is blurred using gaussian kernel of given parameters
11671168
"""
1168-
if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
1169-
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
1169+
1170+
if not (isinstance(img, torch.Tensor)):
1171+
raise TypeError('img should be Tensor. Got {}'.format(type(img)))
1172+
1173+
_assert_image_tensor(img)
11701174

11711175
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
11721176
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
@@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11841188

11851189

11861190
def invert(img: Tensor) -> Tensor:
1187-
if not _is_tensor_a_torch_image(img):
1188-
raise TypeError('tensor is not a torch image.')
1191+
1192+
_assert_image_tensor(img)
11891193

11901194
if img.ndim < 3:
11911195
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:
11971201

11981202

11991203
def posterize(img: Tensor, bits: int) -> Tensor:
1200-
if not _is_tensor_a_torch_image(img):
1201-
raise TypeError('tensor is not a torch image.')
1204+
1205+
_assert_image_tensor(img)
12021206

12031207
if img.ndim < 3:
12041208
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12111215

12121216

12131217
def solarize(img: Tensor, threshold: float) -> Tensor:
1214-
if not _is_tensor_a_torch_image(img):
1215-
raise TypeError('tensor is not a torch image.')
1218+
1219+
_assert_image_tensor(img)
12161220

12171221
if img.ndim < 3:
12181222
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12451249
if sharpness_factor < 0:
12461250
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))
12471251

1248-
if not _is_tensor_a_torch_image(img):
1249-
raise TypeError('tensor is not a torch image.')
1252+
_assert_image_tensor(img)
12501253

12511254
_assert_channels(img, [1, 3])
12521255

@@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12571260

12581261

12591262
def autocontrast(img: Tensor) -> Tensor:
1260-
if not _is_tensor_a_torch_image(img):
1261-
raise TypeError('tensor is not a torch image.')
1263+
1264+
_assert_image_tensor(img)
12621265

12631266
if img.ndim < 3:
12641267
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:
12971300

12981301

12991302
def equalize(img: Tensor) -> Tensor:
1300-
if not _is_tensor_a_torch_image(img):
1301-
raise TypeError('tensor is not a torch image.')
1303+
1304+
_assert_image_tensor(img)
13021305

13031306
if not (3 <= img.ndim <= 4):
13041307
raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))

0 commit comments

Comments
 (0)