@@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
211
211
212
212
213
213
def _setup_fill_arg (fill : Union [FillType , Dict [Type , FillType ]]) -> Dict [Type , FillType ]:
214
+ _check_fill_arg (fill )
215
+
214
216
if isinstance (fill , dict ):
215
217
return fill
216
- else :
217
- return defaultdict (lambda : fill , { features . Mask : 0 } ) # type: ignore[arg-type, return-value]
218
+
219
+ return defaultdict (lambda : fill ) # type: ignore[arg-type, return-value]
218
220
219
221
220
222
def _check_padding_arg (padding : Union [int , Sequence [int ]]) -> None :
@@ -242,7 +244,6 @@ def __init__(
242
244
super ().__init__ ()
243
245
244
246
_check_padding_arg (padding )
245
- _check_fill_arg (fill )
246
247
_check_padding_mode_arg (padding_mode )
247
248
248
249
self .padding = padding
@@ -263,7 +264,6 @@ def __init__(
263
264
) -> None :
264
265
super ().__init__ (p = p )
265
266
266
- _check_fill_arg (fill )
267
267
self .fill = _setup_fill_arg (fill )
268
268
269
269
_check_sequence_input (side_range , "side_range" , req_sizes = (2 ,))
@@ -299,17 +299,15 @@ def __init__(
299
299
degrees : Union [numbers .Number , Sequence ],
300
300
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
301
301
expand : bool = False ,
302
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
302
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
303
303
center : Optional [List [float ]] = None ,
304
304
) -> None :
305
305
super ().__init__ ()
306
306
self .degrees = _setup_angle (degrees , name = "degrees" , req_sizes = (2 ,))
307
307
self .interpolation = interpolation
308
308
self .expand = expand
309
309
310
- _check_fill_arg (fill )
311
-
312
- self .fill = fill
310
+ self .fill = _setup_fill_arg (fill )
313
311
314
312
if center is not None :
315
313
_check_sequence_input (center , "center" , req_sizes = (2 ,))
@@ -321,12 +319,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
321
319
return dict (angle = angle )
322
320
323
321
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
322
+ fill = self .fill [type (inpt )]
324
323
return F .rotate (
325
324
inpt ,
326
325
** params ,
327
326
interpolation = self .interpolation ,
328
327
expand = self .expand ,
329
- fill = self . fill ,
328
+ fill = fill ,
330
329
center = self .center ,
331
330
)
332
331
@@ -339,7 +338,7 @@ def __init__(
339
338
scale : Optional [Sequence [float ]] = None ,
340
339
shear : Optional [Union [float , Sequence [float ]]] = None ,
341
340
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
342
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
341
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
343
342
center : Optional [List [float ]] = None ,
344
343
) -> None :
345
344
super ().__init__ ()
@@ -363,10 +362,7 @@ def __init__(
363
362
self .shear = shear
364
363
365
364
self .interpolation = interpolation
366
-
367
- _check_fill_arg (fill )
368
-
369
- self .fill = fill
365
+ self .fill = _setup_fill_arg (fill )
370
366
371
367
if center is not None :
372
368
_check_sequence_input (center , "center" , req_sizes = (2 ,))
@@ -404,11 +400,12 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
404
400
return dict (angle = angle , translate = translate , scale = scale , shear = shear )
405
401
406
402
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
403
+ fill = self .fill [type (inpt )]
407
404
return F .affine (
408
405
inpt ,
409
406
** params ,
410
407
interpolation = self .interpolation ,
411
- fill = self . fill ,
408
+ fill = fill ,
412
409
center = self .center ,
413
410
)
414
411
@@ -419,7 +416,7 @@ def __init__(
419
416
size : Union [int , Sequence [int ]],
420
417
padding : Optional [Union [int , Sequence [int ]]] = None ,
421
418
pad_if_needed : bool = False ,
422
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
419
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
423
420
padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
424
421
) -> None :
425
422
super ().__init__ ()
@@ -429,12 +426,11 @@ def __init__(
429
426
if pad_if_needed or padding is not None :
430
427
if padding is not None :
431
428
_check_padding_arg (padding )
432
- _check_fill_arg (fill )
433
429
_check_padding_mode_arg (padding_mode )
434
430
435
431
self .padding = padding
436
432
self .pad_if_needed = pad_if_needed
437
- self .fill = fill
433
+ self .fill = _setup_fill_arg ( fill )
438
434
self .padding_mode = padding_mode
439
435
440
436
def _get_params (self , sample : Any ) -> Dict [str , Any ]:
@@ -483,17 +479,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
483
479
484
480
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
485
481
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls
482
+ fill = self .fill [type (inpt )]
486
483
if self .padding is not None :
487
- inpt = F .pad (inpt , padding = self .padding , fill = self . fill , padding_mode = self .padding_mode )
484
+ inpt = F .pad (inpt , padding = self .padding , fill = fill , padding_mode = self .padding_mode )
488
485
489
486
if self .pad_if_needed :
490
487
input_width , input_height = params ["input_width" ], params ["input_height" ]
491
488
if input_width < self .size [1 ]:
492
489
padding = [self .size [1 ] - input_width , 0 ]
493
- inpt = F .pad (inpt , padding = padding , fill = self . fill , padding_mode = self .padding_mode )
490
+ inpt = F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
494
491
if input_height < self .size [0 ]:
495
492
padding = [0 , self .size [0 ] - input_height ]
496
- inpt = F .pad (inpt , padding = padding , fill = self . fill , padding_mode = self .padding_mode )
493
+ inpt = F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
497
494
498
495
return F .crop (inpt , top = params ["top" ], left = params ["left" ], height = params ["height" ], width = params ["width" ])
499
496
@@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform):
502
499
def __init__ (
503
500
self ,
504
501
distortion_scale : float = 0.5 ,
505
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
502
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
506
503
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
507
504
p : float = 0.5 ,
508
505
) -> None :
509
506
super ().__init__ (p = p )
510
507
511
- _check_fill_arg (fill )
512
508
if not (0 <= distortion_scale <= 1 ):
513
509
raise ValueError ("Argument distortion_scale value should be between 0 and 1" )
514
510
515
511
self .distortion_scale = distortion_scale
516
512
self .interpolation = interpolation
517
- self .fill = fill
513
+ self .fill = _setup_fill_arg ( fill )
518
514
519
515
def _get_params (self , sample : Any ) -> Dict [str , Any ]:
520
516
# Get image size
@@ -546,10 +542,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
546
542
return dict (startpoints = startpoints , endpoints = endpoints )
547
543
548
544
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
545
+ fill = self .fill [type (inpt )]
549
546
return F .perspective (
550
547
inpt ,
551
548
** params ,
552
- fill = self . fill ,
549
+ fill = fill ,
553
550
interpolation = self .interpolation ,
554
551
)
555
552
@@ -576,17 +573,15 @@ def __init__(
576
573
self ,
577
574
alpha : Union [float , Sequence [float ]] = 50.0 ,
578
575
sigma : Union [float , Sequence [float ]] = 5.0 ,
579
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
576
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
580
577
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
581
578
) -> None :
582
579
super ().__init__ ()
583
580
self .alpha = _setup_float_or_seq (alpha , "alpha" , 2 )
584
581
self .sigma = _setup_float_or_seq (sigma , "sigma" , 2 )
585
582
586
- _check_fill_arg (fill )
587
-
588
583
self .interpolation = interpolation
589
- self .fill = fill
584
+ self .fill = _setup_fill_arg ( fill )
590
585
591
586
def _get_params (self , sample : Any ) -> Dict [str , Any ]:
592
587
# Get image size
@@ -614,10 +609,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
614
609
return dict (displacement = displacement )
615
610
616
611
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
612
+ fill = self .fill [type (inpt )]
617
613
return F .elastic (
618
614
inpt ,
619
615
** params ,
620
- fill = self . fill ,
616
+ fill = fill ,
621
617
interpolation = self .interpolation ,
622
618
)
623
619
@@ -789,14 +785,16 @@ class FixedSizeCrop(Transform):
789
785
def __init__ (
790
786
self ,
791
787
size : Union [int , Sequence [int ]],
792
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
788
+ fill : Union [FillType , Dict [ Type , FillType ]] = 0 ,
793
789
padding_mode : str = "constant" ,
794
790
) -> None :
795
791
super ().__init__ ()
796
792
size = tuple (_setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." ))
797
793
self .crop_height = size [0 ]
798
794
self .crop_width = size [1 ]
799
- self .fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
795
+
796
+ self .fill = _setup_fill_arg (fill )
797
+
800
798
self .padding_mode = padding_mode
801
799
802
800
def _get_params (self , sample : Any ) -> Dict [str , Any ]:
@@ -869,7 +867,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
869
867
)
870
868
871
869
if params ["needs_pad" ]:
872
- inpt = F .pad (inpt , params ["padding" ], fill = self .fill , padding_mode = self .padding_mode )
870
+ fill = self .fill [type (inpt )]
871
+ inpt = F .pad (inpt , params ["padding" ], fill = fill , padding_mode = self .padding_mode )
873
872
874
873
return inpt
875
874
0 commit comments