Skip to content

Commit dbbc5c8

Browse files
authored
[proto] Added dict support for fill arg for remaining transforms (#6599)
* Updated fill arg typehint for affine, perspective and elastic ops * Updated pad op on prototype side * Code updates * Few other minor updates * WIP * WIP * Updates * Update _image.py * Fixed tests
1 parent 2718f73 commit dbbc5c8

File tree

4 files changed

+44
-38
lines changed

4 files changed

+44
-38
lines changed

test/test_prototype_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def test__transform_image_mask(self, fill, mocker):
391391
if isinstance(fill, int):
392392
calls = [
393393
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
394-
mocker.call(mask, padding=1, fill=0, padding_mode="constant"),
394+
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
395395
]
396396
else:
397397
calls = [
@@ -467,7 +467,7 @@ def test__transform_image_mask(self, fill, mocker):
467467
if isinstance(fill, int):
468468
calls = [
469469
mocker.call(image, **params, fill=fill),
470-
mocker.call(mask, **params, fill=0),
470+
mocker.call(mask, **params, fill=fill),
471471
]
472472
else:
473473
calls = [
@@ -1555,7 +1555,7 @@ def test__get_params(self, mocker):
15551555

15561556
@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
15571557
def test__transform(self, mocker, needs):
1558-
fill_sentinel = mocker.MagicMock()
1558+
fill_sentinel = 12
15591559
padding_mode_sentinel = mocker.MagicMock()
15601560

15611561
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)

test/test_prototype_transforms_functional.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,16 @@ def pad_image_tensor():
195195
for image, padding, fill, padding_mode in itertools.product(
196196
make_images(),
197197
[[1], [1, 1], [1, 1, 2, 2]], # padding
198-
[None, 12, 12.0], # fill
198+
[None, 128.0, 128, [12.0], [12.0, 13.0, 14.0]], # fill
199199
["constant", "symmetric", "edge", "reflect"], # padding mode,
200200
):
201+
if padding_mode != "constant" and fill is not None:
202+
# ValueError: Padding mode 'reflect' is not supported if fill is not scalar
203+
continue
204+
205+
if isinstance(fill, list) and len(fill) != image.shape[-3]:
206+
continue
207+
201208
yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)
202209

203210

torchvision/prototype/transforms/_geometry.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
211211

212212

213213
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
214+
_check_fill_arg(fill)
215+
214216
if isinstance(fill, dict):
215217
return fill
216-
else:
217-
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]
218+
219+
return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]
218220

219221

220222
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
@@ -242,7 +244,6 @@ def __init__(
242244
super().__init__()
243245

244246
_check_padding_arg(padding)
245-
_check_fill_arg(fill)
246247
_check_padding_mode_arg(padding_mode)
247248

248249
self.padding = padding
@@ -263,7 +264,6 @@ def __init__(
263264
) -> None:
264265
super().__init__(p=p)
265266

266-
_check_fill_arg(fill)
267267
self.fill = _setup_fill_arg(fill)
268268

269269
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
@@ -299,17 +299,15 @@ def __init__(
299299
degrees: Union[numbers.Number, Sequence],
300300
interpolation: InterpolationMode = InterpolationMode.NEAREST,
301301
expand: bool = False,
302-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
302+
fill: Union[FillType, Dict[Type, FillType]] = 0,
303303
center: Optional[List[float]] = None,
304304
) -> None:
305305
super().__init__()
306306
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
307307
self.interpolation = interpolation
308308
self.expand = expand
309309

310-
_check_fill_arg(fill)
311-
312-
self.fill = fill
310+
self.fill = _setup_fill_arg(fill)
313311

314312
if center is not None:
315313
_check_sequence_input(center, "center", req_sizes=(2,))
@@ -321,12 +319,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
321319
return dict(angle=angle)
322320

323321
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
322+
fill = self.fill[type(inpt)]
324323
return F.rotate(
325324
inpt,
326325
**params,
327326
interpolation=self.interpolation,
328327
expand=self.expand,
329-
fill=self.fill,
328+
fill=fill,
330329
center=self.center,
331330
)
332331

@@ -339,7 +338,7 @@ def __init__(
339338
scale: Optional[Sequence[float]] = None,
340339
shear: Optional[Union[float, Sequence[float]]] = None,
341340
interpolation: InterpolationMode = InterpolationMode.NEAREST,
342-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
341+
fill: Union[FillType, Dict[Type, FillType]] = 0,
343342
center: Optional[List[float]] = None,
344343
) -> None:
345344
super().__init__()
@@ -363,10 +362,7 @@ def __init__(
363362
self.shear = shear
364363

365364
self.interpolation = interpolation
366-
367-
_check_fill_arg(fill)
368-
369-
self.fill = fill
365+
self.fill = _setup_fill_arg(fill)
370366

371367
if center is not None:
372368
_check_sequence_input(center, "center", req_sizes=(2,))
@@ -404,11 +400,12 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
404400
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
405401

406402
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
403+
fill = self.fill[type(inpt)]
407404
return F.affine(
408405
inpt,
409406
**params,
410407
interpolation=self.interpolation,
411-
fill=self.fill,
408+
fill=fill,
412409
center=self.center,
413410
)
414411

@@ -419,7 +416,7 @@ def __init__(
419416
size: Union[int, Sequence[int]],
420417
padding: Optional[Union[int, Sequence[int]]] = None,
421418
pad_if_needed: bool = False,
422-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
419+
fill: Union[FillType, Dict[Type, FillType]] = 0,
423420
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
424421
) -> None:
425422
super().__init__()
@@ -429,12 +426,11 @@ def __init__(
429426
if pad_if_needed or padding is not None:
430427
if padding is not None:
431428
_check_padding_arg(padding)
432-
_check_fill_arg(fill)
433429
_check_padding_mode_arg(padding_mode)
434430

435431
self.padding = padding
436432
self.pad_if_needed = pad_if_needed
437-
self.fill = fill
433+
self.fill = _setup_fill_arg(fill)
438434
self.padding_mode = padding_mode
439435

440436
def _get_params(self, sample: Any) -> Dict[str, Any]:
@@ -483,17 +479,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
483479

484480
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
485481
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls
482+
fill = self.fill[type(inpt)]
486483
if self.padding is not None:
487-
inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
484+
inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
488485

489486
if self.pad_if_needed:
490487
input_width, input_height = params["input_width"], params["input_height"]
491488
if input_width < self.size[1]:
492489
padding = [self.size[1] - input_width, 0]
493-
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
490+
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
494491
if input_height < self.size[0]:
495492
padding = [0, self.size[0] - input_height]
496-
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
493+
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
497494

498495
return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
499496

@@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform):
502499
def __init__(
503500
self,
504501
distortion_scale: float = 0.5,
505-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
502+
fill: Union[FillType, Dict[Type, FillType]] = 0,
506503
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
507504
p: float = 0.5,
508505
) -> None:
509506
super().__init__(p=p)
510507

511-
_check_fill_arg(fill)
512508
if not (0 <= distortion_scale <= 1):
513509
raise ValueError("Argument distortion_scale value should be between 0 and 1")
514510

515511
self.distortion_scale = distortion_scale
516512
self.interpolation = interpolation
517-
self.fill = fill
513+
self.fill = _setup_fill_arg(fill)
518514

519515
def _get_params(self, sample: Any) -> Dict[str, Any]:
520516
# Get image size
@@ -546,10 +542,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
546542
return dict(startpoints=startpoints, endpoints=endpoints)
547543

548544
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
545+
fill = self.fill[type(inpt)]
549546
return F.perspective(
550547
inpt,
551548
**params,
552-
fill=self.fill,
549+
fill=fill,
553550
interpolation=self.interpolation,
554551
)
555552

@@ -576,17 +573,15 @@ def __init__(
576573
self,
577574
alpha: Union[float, Sequence[float]] = 50.0,
578575
sigma: Union[float, Sequence[float]] = 5.0,
579-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
576+
fill: Union[FillType, Dict[Type, FillType]] = 0,
580577
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
581578
) -> None:
582579
super().__init__()
583580
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
584581
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
585582

586-
_check_fill_arg(fill)
587-
588583
self.interpolation = interpolation
589-
self.fill = fill
584+
self.fill = _setup_fill_arg(fill)
590585

591586
def _get_params(self, sample: Any) -> Dict[str, Any]:
592587
# Get image size
@@ -614,10 +609,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
614609
return dict(displacement=displacement)
615610

616611
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
612+
fill = self.fill[type(inpt)]
617613
return F.elastic(
618614
inpt,
619615
**params,
620-
fill=self.fill,
616+
fill=fill,
621617
interpolation=self.interpolation,
622618
)
623619

@@ -789,14 +785,16 @@ class FixedSizeCrop(Transform):
789785
def __init__(
790786
self,
791787
size: Union[int, Sequence[int]],
792-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
788+
fill: Union[FillType, Dict[Type, FillType]] = 0,
793789
padding_mode: str = "constant",
794790
) -> None:
795791
super().__init__()
796792
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
797793
self.crop_height = size[0]
798794
self.crop_width = size[1]
799-
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
795+
796+
self.fill = _setup_fill_arg(fill)
797+
800798
self.padding_mode = padding_mode
801799

802800
def _get_params(self, sample: Any) -> Dict[str, Any]:
@@ -869,7 +867,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
869867
)
870868

871869
if params["needs_pad"]:
872-
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
870+
fill = self.fill[type(inpt)]
871+
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
873872

874873
return inpt
875874

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
499499

500500
# Check fill
501501
num_channels = get_dimensions(img)[0]
502-
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
502+
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
503503
msg = (
504504
"The number of elements in 'fill' cannot broadcast to match the number of "
505505
"channels of the image ({} != {})"

0 commit comments

Comments
 (0)