Skip to content

Commit 414a1ee

Browse files
committed
Minor code-quality changes on Geometical Transforms.
1 parent 08ae56f commit 414a1ee

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

torchvision/prototype/transforms/_geometry.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,16 @@ def __init__(
223223
_check_padding_arg(padding)
224224
_check_padding_mode_arg(padding_mode)
225225

226+
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
227+
if not isinstance(padding, int):
228+
padding = list(padding)
226229
self.padding = padding
227230
self.fill = _setup_fill_arg(fill)
228231
self.padding_mode = padding_mode
229232

230233
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
231234
fill = self.fill[type(inpt)]
232-
233-
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
234-
padding = self.padding
235-
if not isinstance(padding, int):
236-
padding = list(padding)
237-
238-
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
235+
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
239236

240237

241238
class RandomZoomOut(_RandomApplyTransform):
@@ -298,7 +295,7 @@ def __init__(
298295
self.center = center
299296

300297
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
301-
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
298+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
302299
return dict(angle=angle)
303300

304301
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -355,7 +352,7 @@ def __init__(
355352
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
356353
height, width = query_spatial_size(flat_inputs)
357354

358-
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
355+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
359356
if self.translate is not None:
360357
max_dx = float(self.translate[0] * width)
361358
max_dy = float(self.translate[1] * height)
@@ -366,15 +363,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
366363
translate = (0, 0)
367364

368365
if self.scale is not None:
369-
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
366+
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
370367
else:
371368
scale = 1.0
372369

373370
shear_x = shear_y = 0.0
374371
if self.shear is not None:
375-
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
372+
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
376373
if len(self.shear) == 4:
377-
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
374+
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
378375

379376
shear = (shear_x, shear_y)
380377
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
@@ -451,12 +448,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
451448
needs_pad = any(padding)
452449

453450
needs_vert_crop, top = (
454-
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
451+
(True, torch.randint(0, padded_height - cropped_height + 1, size=()).item())
455452
if padded_height > cropped_height
456453
else (False, 0)
457454
)
458455
needs_horz_crop, left = (
459-
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
456+
(True, torch.randint(0, padded_width - cropped_width + 1, size=()).item())
460457
if padded_width > cropped_width
461458
else (False, 0)
462459
)
@@ -506,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
506503

507504
half_height = height // 2
508505
half_width = width // 2
506+
bound_height = int(distortion_scale * half_height) + 1
507+
bound_width = int(distortion_scale * half_width) + 1
509508
topleft = [
510-
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
511-
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
509+
torch.randint(0, bound_width + 1, size=(1,)).item(),
510+
torch.randint(0, bound_height, size=(1,)).item(),
512511
]
513512
topright = [
514-
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
515-
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
513+
torch.randint(width - bound_width, width, size=(1,)).item(),
514+
torch.randint(0, bound_height, size=(1,)).item(),
516515
]
517516
botright = [
518-
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
519-
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
517+
torch.randint(width - bound_width, width, size=(1,)).item(),
518+
torch.randint(height - bound_height, height, size=(1,)).item(),
520519
]
521520
botleft = [
522-
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
523-
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
521+
torch.randint(0, bound_width, size=(1,)).item(),
522+
torch.randint(height - bound_height, height, size=(1,)).item(),
524523
]
525524
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
526525
endpoints = [topleft, topright, botright, botleft]
@@ -623,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
623622

624623
while True:
625624
# sample an option
626-
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
625+
idx = torch.randint(low=0, high=len(self.options), size=(1,)).item()
627626
min_jaccard_overlap = self.options[idx]
628627
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
629628
return dict()

0 commit comments

Comments
 (0)