@@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
11
11
return x .ndim >= 2
12
12
13
13
14
+ def _assert_image_tensor (img ):
15
+ if not _is_tensor_a_torch_image (img ):
16
+ raise TypeError ("Tensor is not a torch image." )
17
+
18
+
14
19
def _get_image_size (img : Tensor ) -> List [int ]:
15
20
"""Returns (w, h) of tensor image"""
16
- if _is_tensor_a_torch_image (img ):
17
- return [img .shape [- 1 ], img .shape [- 2 ]]
18
- raise TypeError ("Unexpected input type" )
21
+ _assert_image_tensor (img )
22
+ return [img .shape [- 1 ], img .shape [- 2 ]]
19
23
20
24
21
25
def _get_image_num_channels (img : Tensor ) -> int :
@@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
143
147
Returns:
144
148
Tensor: Vertically flipped image Tensor.
145
149
"""
146
- if not _is_tensor_a_torch_image (img ):
147
- raise TypeError ('tensor is not a torch image.' )
150
+ _assert_image_tensor (img )
148
151
149
152
return img .flip (- 2 )
150
153
@@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
163
166
Returns:
164
167
Tensor: Horizontally flipped image Tensor.
165
168
"""
166
- if not _is_tensor_a_torch_image (img ):
167
- raise TypeError ('tensor is not a torch image.' )
169
+ _assert_image_tensor (img )
168
170
169
171
return img .flip (- 1 )
170
172
@@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
187
189
Returns:
188
190
Tensor: Cropped image.
189
191
"""
190
- if not _is_tensor_a_torch_image (img ):
191
- raise TypeError ("tensor is not a torch image." )
192
+ _assert_image_tensor (img )
192
193
193
194
return img [..., top :top + height , left :left + width ]
194
195
@@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
254
255
if brightness_factor < 0 :
255
256
raise ValueError ('brightness_factor ({}) is not non-negative.' .format (brightness_factor ))
256
257
257
- if not _is_tensor_a_torch_image (img ):
258
- raise TypeError ('tensor is not a torch image.' )
258
+ _assert_image_tensor (img )
259
259
260
260
_assert_channels (img , [1 , 3 ])
261
261
@@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
282
282
if contrast_factor < 0 :
283
283
raise ValueError ('contrast_factor ({}) is not non-negative.' .format (contrast_factor ))
284
284
285
- if not _is_tensor_a_torch_image (img ):
286
- raise TypeError ('tensor is not a torch image.' )
285
+ _assert_image_tensor (img )
287
286
288
287
_assert_channels (img , [3 ])
289
288
@@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
326
325
if not (- 0.5 <= hue_factor <= 0.5 ):
327
326
raise ValueError ('hue_factor ({}) is not in [-0.5, 0.5].' .format (hue_factor ))
328
327
329
- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image ( img ) ):
328
+ if not (isinstance (img , torch .Tensor )):
330
329
raise TypeError ('Input img should be Tensor image' )
331
330
331
+ _assert_image_tensor (img )
332
+
332
333
_assert_channels (img , [3 ])
333
334
334
335
orig_dtype = img .dtype
@@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
367
368
if saturation_factor < 0 :
368
369
raise ValueError ('saturation_factor ({}) is not non-negative.' .format (saturation_factor ))
369
370
370
- if not _is_tensor_a_torch_image (img ):
371
- raise TypeError ('tensor is not a torch image.' )
371
+ _assert_image_tensor (img )
372
372
373
373
_assert_channels (img , [3 ])
374
374
@@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
447
447
"Please, use ``F.center_crop`` instead."
448
448
)
449
449
450
- if not _is_tensor_a_torch_image (img ):
451
- raise TypeError ('tensor is not a torch image.' )
450
+ _assert_image_tensor (img )
452
451
453
452
_ , image_width , image_height = img .size ()
454
453
crop_height , crop_width = output_size
@@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
497
496
"Please, use ``F.five_crop`` instead."
498
497
)
499
498
500
- if not _is_tensor_a_torch_image (img ):
501
- raise TypeError ('tensor is not a torch image.' )
499
+ _assert_image_tensor (img )
502
500
503
501
assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
504
502
@@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
553
551
"Please, use ``F.ten_crop`` instead."
554
552
)
555
553
556
- if not _is_tensor_a_torch_image (img ):
557
- raise TypeError ('tensor is not a torch image.' )
554
+ _assert_image_tensor (img )
558
555
559
556
assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
560
557
first_five = five_crop (img , size )
@@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
703
700
Returns:
704
701
Tensor: Padded image.
705
702
"""
706
- if not _is_tensor_a_torch_image (img ):
707
- raise TypeError ("tensor is not a torch image." )
703
+ _assert_image_tensor (img )
708
704
709
705
if not isinstance (padding , (int , tuple , list )):
710
706
raise TypeError ("Got inappropriate padding arg" )
@@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
796
792
Returns:
797
793
Tensor: Resized image.
798
794
"""
799
- if not _is_tensor_a_torch_image (img ):
800
- raise TypeError ("tensor is not a torch image." )
795
+ _assert_image_tensor (img )
801
796
802
797
if not isinstance (size , (int , tuple , list )):
803
798
raise TypeError ("Got inappropriate size arg" )
@@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
855
850
supported_interpolation_modes : List [str ],
856
851
coeffs : Optional [List [float ]] = None ,
857
852
):
858
- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
859
- raise TypeError ("Input img should be Tensor Image" )
853
+
854
+ if not (isinstance (img , torch .Tensor )):
855
+ raise TypeError ("Input img should be Tensor" )
856
+
857
+ _assert_image_tensor (img )
860
858
861
859
if matrix is not None and not isinstance (matrix , list ):
862
860
raise TypeError ("Argument matrix should be a list" )
@@ -1112,8 +1110,11 @@ def perspective(
1112
1110
Returns:
1113
1111
Tensor: transformed image.
1114
1112
"""
1115
- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
1116
- raise TypeError ('Input img should be Tensor Image' )
1113
+
1114
+ if not (isinstance (img , torch .Tensor )):
1115
+ raise TypeError ('Input img should be Tensor.' )
1116
+
1117
+ _assert_image_tensor (img )
1117
1118
1118
1119
_assert_grid_transform_inputs (
1119
1120
img ,
@@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
1165
1166
Returns:
1166
1167
Tensor: An image that is blurred using gaussian kernel of given parameters
1167
1168
"""
1168
- if not (isinstance (img , torch .Tensor ) or _is_tensor_a_torch_image (img )):
1169
- raise TypeError ('img should be Tensor Image. Got {}' .format (type (img )))
1169
+
1170
+ if not (isinstance (img , torch .Tensor )):
1171
+ raise TypeError ('img should be Tensor. Got {}' .format (type (img )))
1172
+
1173
+ _assert_image_tensor (img )
1170
1174
1171
1175
dtype = img .dtype if torch .is_floating_point (img ) else torch .float32
1172
1176
kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype , device = img .device )
@@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
1184
1188
1185
1189
1186
1190
def invert (img : Tensor ) -> Tensor :
1187
- if not _is_tensor_a_torch_image ( img ):
1188
- raise TypeError ( 'tensor is not a torch image.' )
1191
+
1192
+ _assert_image_tensor ( img )
1189
1193
1190
1194
if img .ndim < 3 :
1191
1195
raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:
1197
1201
1198
1202
1199
1203
def posterize (img : Tensor , bits : int ) -> Tensor :
1200
- if not _is_tensor_a_torch_image ( img ):
1201
- raise TypeError ( 'tensor is not a torch image.' )
1204
+
1205
+ _assert_image_tensor ( img )
1202
1206
1203
1207
if img .ndim < 3 :
1204
1208
raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:
1211
1215
1212
1216
1213
1217
def solarize (img : Tensor , threshold : float ) -> Tensor :
1214
- if not _is_tensor_a_torch_image ( img ):
1215
- raise TypeError ( 'tensor is not a torch image.' )
1218
+
1219
+ _assert_image_tensor ( img )
1216
1220
1217
1221
if img .ndim < 3 :
1218
1222
raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
1245
1249
if sharpness_factor < 0 :
1246
1250
raise ValueError ('sharpness_factor ({}) is not non-negative.' .format (sharpness_factor ))
1247
1251
1248
- if not _is_tensor_a_torch_image (img ):
1249
- raise TypeError ('tensor is not a torch image.' )
1252
+ _assert_image_tensor (img )
1250
1253
1251
1254
_assert_channels (img , [1 , 3 ])
1252
1255
@@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
1257
1260
1258
1261
1259
1262
def autocontrast (img : Tensor ) -> Tensor :
1260
- if not _is_tensor_a_torch_image ( img ):
1261
- raise TypeError ( 'tensor is not a torch image.' )
1263
+
1264
+ _assert_image_tensor ( img )
1262
1265
1263
1266
if img .ndim < 3 :
1264
1267
raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:
1297
1300
1298
1301
1299
1302
def equalize (img : Tensor ) -> Tensor :
1300
- if not _is_tensor_a_torch_image ( img ):
1301
- raise TypeError ( 'tensor is not a torch image.' )
1303
+
1304
+ _assert_image_tensor ( img )
1302
1305
1303
1306
if not (3 <= img .ndim <= 4 ):
1304
1307
raise TypeError ("Input image tensor should have 3 or 4 dimensions, but found {}" .format (img .ndim ))
0 commit comments