Skip to content

Commit 2fa52d6

Browse files
authored
Merge branch 'main' into main
2 parents 950d0d8 + 053e7eb commit 2fa52d6

File tree

3 files changed: 64 additions and 29 deletions.

3 files changed: 64 additions and 29 deletions.

test/test_prototype_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ class TestSmoke:
7171
transforms.CenterCrop([16, 16]),
7272
transforms.ConvertImageDtype(),
7373
transforms.RandomHorizontalFlip(),
74+
transforms.Pad(5),
7475
)
7576
def test_common(self, transform, input):
7677
transform(input)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
TenCrop,
1616
BatchMultiCrop,
1717
RandomHorizontalFlip,
18+
Pad,
1819
RandomZoomOut,
1920
)
2021
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace

torchvision/prototype/transforms/_geometry.py

Lines changed: 62 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import collections.abc
22
import math
3+
import numbers
34
import warnings
45
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
56

@@ -9,6 +10,7 @@
910
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
1011
from torchvision.transforms.functional import pil_to_tensor
1112
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
13+
from typing_extensions import Literal
1214

1315
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
1416

@@ -272,42 +274,31 @@ def apply_recursively(obj: Any) -> Any:
272274
return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
273275

274276

275-
class RandomZoomOut(Transform):
277+
class Pad(Transform):
276278
def __init__(
277-
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
279+
self,
280+
padding: Union[int, Sequence[int]],
281+
fill: Union[float, Sequence[float]] = 0.0,
282+
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
278283
) -> None:
279284
super().__init__()
285+
if not isinstance(padding, (numbers.Number, tuple, list)):
286+
raise TypeError("Got inappropriate padding arg")
280287

281-
if fill is None:
282-
fill = 0.0
283-
self.fill = fill
284-
285-
self.side_range = side_range
286-
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
287-
raise ValueError(f"Invalid canvas side range provided {side_range}.")
288-
289-
self.p = p
290-
291-
def _get_params(self, sample: Any) -> Dict[str, Any]:
292-
image = query_image(sample)
293-
orig_c, orig_h, orig_w = get_image_dimensions(image)
294-
295-
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
296-
canvas_width = int(orig_w * r)
297-
canvas_height = int(orig_h * r)
288+
if not isinstance(fill, (numbers.Number, str, tuple, list)):
289+
raise TypeError("Got inappropriate fill arg")
298290

299-
r = torch.rand(2)
300-
left = int((canvas_width - orig_w) * r[0])
301-
top = int((canvas_height - orig_h) * r[1])
302-
right = canvas_width - (left + orig_w)
303-
bottom = canvas_height - (top + orig_h)
304-
padding = [left, top, right, bottom]
291+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
292+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
305293

306-
fill = self.fill
307-
if not isinstance(fill, collections.abc.Sequence):
308-
fill = [fill] * orig_c
294+
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
295+
raise ValueError(
296+
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
297+
)
309298

310-
return dict(padding=padding, fill=fill)
299+
self.padding = padding
300+
self.fill = fill
301+
self.padding_mode = padding_mode
311302

312303
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
313304
if isinstance(input, features.Image) or is_simple_tensor(input):
@@ -349,6 +340,48 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
349340
else:
350341
return input
351342

343+
344+
class RandomZoomOut(Transform):
345+
def __init__(
346+
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
347+
) -> None:
348+
super().__init__()
349+
350+
if fill is None:
351+
fill = 0.0
352+
self.fill = fill
353+
354+
self.side_range = side_range
355+
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
356+
raise ValueError(f"Invalid canvas side range provided {side_range}.")
357+
358+
self.p = p
359+
360+
def _get_params(self, sample: Any) -> Dict[str, Any]:
361+
image = query_image(sample)
362+
orig_c, orig_h, orig_w = get_image_dimensions(image)
363+
364+
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
365+
canvas_width = int(orig_w * r)
366+
canvas_height = int(orig_h * r)
367+
368+
r = torch.rand(2)
369+
left = int((canvas_width - orig_w) * r[0])
370+
top = int((canvas_height - orig_h) * r[1])
371+
right = canvas_width - (left + orig_w)
372+
bottom = canvas_height - (top + orig_h)
373+
padding = [left, top, right, bottom]
374+
375+
fill = self.fill
376+
if not isinstance(fill, collections.abc.Sequence):
377+
fill = [fill] * orig_c
378+
379+
return dict(padding=padding, fill=fill)
380+
381+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
382+
transform = Pad(**params, padding_mode="constant")
383+
return transform(input)
384+
352385
def forward(self, *inputs: Any) -> Any:
353386
sample = inputs if len(inputs) > 1 else inputs[0]
354387
if torch.rand(1) >= self.p:

0 commit comments

Comments (0)