@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs(
475
475
img : Tensor ,
476
476
matrix : Optional [List [float ]],
477
477
interpolation : str ,
478
- fill : Optional [List [float ]],
478
+ fill : Optional [Union [ int , float , List [float ] ]],
479
479
supported_interpolation_modes : List [str ],
480
480
coeffs : Optional [List [float ]] = None ,
481
481
) -> None :
@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
499
499
500
500
# Check fill
501
501
num_channels = get_dimensions (img )[0 ]
502
- if isinstance (fill , (tuple , list )) and (len (fill ) > 1 and len (fill ) != num_channels ):
502
+ if fill is not None and isinstance (fill , (tuple , list )) and (len (fill ) > 1 and len (fill ) != num_channels ):
503
503
msg = (
504
504
"The number of elements in 'fill' cannot broadcast to match the number of "
505
505
"channels of the image ({} != {})"
@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
539
539
return img
540
540
541
541
542
- def _apply_grid_transform (img : Tensor , grid : Tensor , mode : str , fill : Optional [List [float ]]) -> Tensor :
542
+ def _apply_grid_transform (
543
+ img : Tensor , grid : Tensor , mode : str , fill : Optional [Union [int , float , List [float ]]]
544
+ ) -> Tensor :
543
545
544
546
img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , [grid .dtype ])
545
547
@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
559
561
mask = img [:, - 1 :, :, :] # N * 1 * H * W
560
562
img = img [:, :- 1 , :, :] # N * C * H * W
561
563
mask = mask .expand_as (img )
562
- len_fill = len (fill ) if isinstance (fill , (tuple , list )) else 1
563
- fill_img = torch .tensor (fill , dtype = img .dtype , device = img .device ).view (1 , len_fill , 1 , 1 ).expand_as (img )
564
+ fill_list , len_fill = ( fill , len (fill )) if isinstance (fill , (tuple , list )) else ([ float ( fill )], 1 )
565
+ fill_img = torch .tensor (fill_list , dtype = img .dtype , device = img .device ).view (1 , len_fill , 1 , 1 ).expand_as (img )
564
566
if mode == "nearest" :
565
567
mask = mask < 0.5
566
568
img [mask ] = fill_img [mask ]
@@ -648,7 +650,7 @@ def rotate(
648
650
matrix : List [float ],
649
651
interpolation : str = "nearest" ,
650
652
expand : bool = False ,
651
- fill : Optional [List [float ]] = None ,
653
+ fill : Optional [Union [ int , float , List [float ] ]] = None ,
652
654
) -> Tensor :
653
655
_assert_grid_transform_inputs (img , matrix , interpolation , fill , ["nearest" , "bilinear" ])
654
656
w , h = img .shape [- 1 ], img .shape [- 2 ]
0 commit comments