@@ -223,19 +223,16 @@ def __init__(
223
223
_check_padding_arg (padding )
224
224
_check_padding_mode_arg (padding_mode )
225
225
226
+ # This cast does Sequence[int] -> List[int] and is required to make mypy happy
227
+ if not isinstance (padding , int ):
228
+ padding = list (padding )
226
229
self .padding = padding
227
230
self .fill = _setup_fill_arg (fill )
228
231
self .padding_mode = padding_mode
229
232
230
233
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
231
234
fill = self .fill [type (inpt )]
232
-
233
- # This cast does Sequence[int] -> List[int] and is required to make mypy happy
234
- padding = self .padding
235
- if not isinstance (padding , int ):
236
- padding = list (padding )
237
-
238
- return F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
235
+ return F .pad (inpt , padding = self .padding , fill = fill , padding_mode = self .padding_mode )
239
236
240
237
241
238
class RandomZoomOut (_RandomApplyTransform ):
@@ -298,7 +295,7 @@ def __init__(
298
295
self .center = center
299
296
300
297
def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
301
- angle = float ( torch .empty (1 ).uniform_ (float ( self .degrees [0 ]), float ( self .degrees [1 ])) .item () )
298
+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
302
299
return dict (angle = angle )
303
300
304
301
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
@@ -355,7 +352,7 @@ def __init__(
355
352
def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
356
353
height , width = query_spatial_size (flat_inputs )
357
354
358
- angle = float ( torch .empty (1 ).uniform_ (float ( self .degrees [0 ]), float ( self .degrees [1 ])) .item () )
355
+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
359
356
if self .translate is not None :
360
357
max_dx = float (self .translate [0 ] * width )
361
358
max_dy = float (self .translate [1 ] * height )
@@ -366,15 +363,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
366
363
translate = (0 , 0 )
367
364
368
365
if self .scale is not None :
369
- scale = float ( torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item () )
366
+ scale = torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item ()
370
367
else :
371
368
scale = 1.0
372
369
373
370
shear_x = shear_y = 0.0
374
371
if self .shear is not None :
375
- shear_x = float ( torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ]).item () )
372
+ shear_x = torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ]).item ()
376
373
if len (self .shear ) == 4 :
377
- shear_y = float ( torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ]).item () )
374
+ shear_y = torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ]).item ()
378
375
379
376
shear = (shear_x , shear_y )
380
377
return dict (angle = angle , translate = translate , scale = scale , shear = shear )
@@ -451,12 +448,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
451
448
needs_pad = any (padding )
452
449
453
450
needs_vert_crop , top = (
454
- (True , int ( torch .randint (0 , padded_height - cropped_height + 1 , size = ())))
451
+ (True , torch .randint (0 , padded_height - cropped_height + 1 , size = ()). item ( ))
455
452
if padded_height > cropped_height
456
453
else (False , 0 )
457
454
)
458
455
needs_horz_crop , left = (
459
- (True , int ( torch .randint (0 , padded_width - cropped_width + 1 , size = ())))
456
+ (True , torch .randint (0 , padded_width - cropped_width + 1 , size = ()). item ( ))
460
457
if padded_width > cropped_width
461
458
else (False , 0 )
462
459
)
@@ -506,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
506
503
507
504
half_height = height // 2
508
505
half_width = width // 2
506
+ bound_height = int (distortion_scale * half_height ) + 1
507
+ bound_width = int (distortion_scale * half_width ) + 1
509
508
topleft = [
510
- int ( torch .randint (0 , int ( distortion_scale * half_width ) + 1 , size = (1 ,)).item () ),
511
- int ( torch .randint (0 , int ( distortion_scale * half_height ) + 1 , size = (1 ,)).item () ),
509
+ torch .randint (0 , bound_width , size = (1 ,)).item (),
510
+ torch .randint (0 , bound_height , size = (1 ,)).item (),
512
511
]
513
512
topright = [
514
- int ( torch .randint (width - int ( distortion_scale * half_width ) - 1 , width , size = (1 ,)).item () ),
515
- int ( torch .randint (0 , int ( distortion_scale * half_height ) + 1 , size = (1 ,)).item () ),
513
+ torch .randint (width - bound_width , width , size = (1 ,)).item (),
514
+ torch .randint (0 , bound_height , size = (1 ,)).item (),
516
515
]
517
516
botright = [
518
- int ( torch .randint (width - int ( distortion_scale * half_width ) - 1 , width , size = (1 ,)).item () ),
519
- int ( torch .randint (height - int ( distortion_scale * half_height ) - 1 , height , size = (1 ,)).item () ),
517
+ torch .randint (width - bound_width , width , size = (1 ,)).item (),
518
+ torch .randint (height - bound_height , height , size = (1 ,)).item (),
520
519
]
521
520
botleft = [
522
- int ( torch .randint (0 , int ( distortion_scale * half_width ) + 1 , size = (1 ,)).item () ),
523
- int ( torch .randint (height - int ( distortion_scale * half_height ) - 1 , height , size = (1 ,)).item () ),
521
+ torch .randint (0 , bound_width , size = (1 ,)).item (),
522
+ torch .randint (height - bound_height , height , size = (1 ,)).item (),
524
523
]
525
524
startpoints = [[0 , 0 ], [width - 1 , 0 ], [width - 1 , height - 1 ], [0 , height - 1 ]]
526
525
endpoints = [topleft , topright , botright , botleft ]
@@ -623,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
623
622
624
623
while True :
625
624
# sample an option
626
- idx = int ( torch .randint (low = 0 , high = len (self .options ), size = (1 ,)))
625
+ idx = torch .randint (low = 0 , high = len (self .options ), size = (1 ,)). item ( )
627
626
min_jaccard_overlap = self .options [idx ]
628
627
if min_jaccard_overlap >= 1.0 : # a value larger than 1 encodes the leave as-is option
629
628
return dict ()
0 commit comments