|
1 | 1 | import collections.abc
|
2 | 2 | import math
|
| 3 | +import numbers |
3 | 4 | import warnings
|
4 | 5 | from typing import Any, Dict, List, Union, Sequence, Tuple, cast
|
5 | 6 |
|
|
9 | 10 | from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
|
10 | 11 | from torchvision.transforms.functional import pil_to_tensor
|
11 | 12 | from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
|
| 13 | +from typing_extensions import Literal |
12 | 14 |
|
13 | 15 | from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
|
14 | 16 |
|
@@ -272,42 +274,31 @@ def apply_recursively(obj: Any) -> Any:
|
272 | 274 | return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
|
273 | 275 |
|
274 | 276 |
|
275 |
| -class RandomZoomOut(Transform): |
| 277 | +class Pad(Transform): |
276 | 278 | def __init__(
|
277 |
| - self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 |
| 279 | + self, |
| 280 | + padding: Union[int, Sequence[int]], |
| 281 | + fill: Union[float, Sequence[float]] = 0.0, |
| 282 | + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", |
278 | 283 | ) -> None:
|
279 | 284 | super().__init__()
|
| 285 | + if not isinstance(padding, (numbers.Number, tuple, list)): |
| 286 | + raise TypeError("Got inappropriate padding arg") |
280 | 287 |
|
281 |
| - if fill is None: |
282 |
| - fill = 0.0 |
283 |
| - self.fill = fill |
284 |
| - |
285 |
| - self.side_range = side_range |
286 |
| - if side_range[0] < 1.0 or side_range[0] > side_range[1]: |
287 |
| - raise ValueError(f"Invalid canvas side range provided {side_range}.") |
288 |
| - |
289 |
| - self.p = p |
290 |
| - |
291 |
| - def _get_params(self, sample: Any) -> Dict[str, Any]: |
292 |
| - image = query_image(sample) |
293 |
| - orig_c, orig_h, orig_w = get_image_dimensions(image) |
294 |
| - |
295 |
| - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) |
296 |
| - canvas_width = int(orig_w * r) |
297 |
| - canvas_height = int(orig_h * r) |
| 288 | + if not isinstance(fill, (numbers.Number, str, tuple, list)): |
| 289 | + raise TypeError("Got inappropriate fill arg") |
298 | 290 |
|
299 |
| - r = torch.rand(2) |
300 |
| - left = int((canvas_width - orig_w) * r[0]) |
301 |
| - top = int((canvas_height - orig_h) * r[1]) |
302 |
| - right = canvas_width - (left + orig_w) |
303 |
| - bottom = canvas_height - (top + orig_h) |
304 |
| - padding = [left, top, right, bottom] |
| 291 | + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: |
| 292 | + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") |
305 | 293 |
|
306 |
| - fill = self.fill |
307 |
| - if not isinstance(fill, collections.abc.Sequence): |
308 |
| - fill = [fill] * orig_c |
| 294 | + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: |
| 295 | + raise ValueError( |
| 296 | + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" |
| 297 | + ) |
309 | 298 |
|
310 |
| - return dict(padding=padding, fill=fill) |
| 299 | + self.padding = padding |
| 300 | + self.fill = fill |
| 301 | + self.padding_mode = padding_mode |
311 | 302 |
|
312 | 303 | def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
|
313 | 304 | if isinstance(input, features.Image) or is_simple_tensor(input):
|
@@ -349,6 +340,48 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
|
349 | 340 | else:
|
350 | 341 | return input
|
351 | 342 |
|
| 343 | + |
| 344 | +class RandomZoomOut(Transform): |
| 345 | + def __init__( |
| 346 | + self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 |
| 347 | + ) -> None: |
| 348 | + super().__init__() |
| 349 | + |
| 350 | + if fill is None: |
| 351 | + fill = 0.0 |
| 352 | + self.fill = fill |
| 353 | + |
| 354 | + self.side_range = side_range |
| 355 | + if side_range[0] < 1.0 or side_range[0] > side_range[1]: |
| 356 | + raise ValueError(f"Invalid canvas side range provided {side_range}.") |
| 357 | + |
| 358 | + self.p = p |
| 359 | + |
| 360 | + def _get_params(self, sample: Any) -> Dict[str, Any]: |
| 361 | + image = query_image(sample) |
| 362 | + orig_c, orig_h, orig_w = get_image_dimensions(image) |
| 363 | + |
| 364 | + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) |
| 365 | + canvas_width = int(orig_w * r) |
| 366 | + canvas_height = int(orig_h * r) |
| 367 | + |
| 368 | + r = torch.rand(2) |
| 369 | + left = int((canvas_width - orig_w) * r[0]) |
| 370 | + top = int((canvas_height - orig_h) * r[1]) |
| 371 | + right = canvas_width - (left + orig_w) |
| 372 | + bottom = canvas_height - (top + orig_h) |
| 373 | + padding = [left, top, right, bottom] |
| 374 | + |
| 375 | + fill = self.fill |
| 376 | + if not isinstance(fill, collections.abc.Sequence): |
| 377 | + fill = [fill] * orig_c |
| 378 | + |
| 379 | + return dict(padding=padding, fill=fill) |
| 380 | + |
| 381 | + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: |
| 382 | + transform = Pad(**params, padding_mode="constant") |
| 383 | + return transform(input) |
| 384 | + |
352 | 385 | def forward(self, *inputs: Any) -> Any:
|
353 | 386 | sample = inputs if len(inputs) > 1 else inputs[0]
|
354 | 387 | if torch.rand(1) >= self.p:
|
|
0 commit comments