diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8ae75f84c5b..c1a350955db 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -724,7 +724,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if isinstance(fill, int): - fill = tuple([fill] * 3) + fill = tuple([fill] * len(img.getbands())) return img.rotate(angle, resample, expand, center, fillcolor=fill)