
Commit 52dc452

Merge branch 'main' into refactoring/5523-random-horizontal-flip
2 parents 6e84c7a + 6013230

File tree

16 files changed: +292 -53 lines changed


README.rst

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ supported Python versions.
 +--------------------------+--------------------------+---------------------------------+
 | ``torch``                | ``torchvision``          | ``python``                      |
 +==========================+==========================+=================================+
-| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.9``            |
+| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.10``           |
++--------------------------+--------------------------+---------------------------------+
+| ``1.11.0``               | ``0.12.0``               | ``>=3.7``, ``<=3.10``           |
 +--------------------------+--------------------------+---------------------------------+
 | ``1.10.2``               | ``0.11.3``               | ``>=3.6``, ``<=3.9``            |
 +--------------------------+--------------------------+---------------------------------+

references/detection/transforms.py

Lines changed: 112 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Union
 
 import torch
 import torchvision
@@ -326,3 +326,114 @@ def forward(
             )
 
         return image, target
+
+
+class FixedSizeCrop(nn.Module):
+    def __init__(self, size, fill=0, padding_mode="constant"):
+        super().__init__()
+        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _pad(self, img, target, padding):
+        # Taken from the functional_tensor.py pad
+        if isinstance(padding, int):
+            pad_left = pad_right = pad_top = pad_bottom = padding
+        elif len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        else:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+
+        padding = [pad_left, pad_top, pad_right, pad_bottom]
+        img = F.pad(img, padding, self.fill, self.padding_mode)
+        if target is not None:
+            target["boxes"][:, 0::2] += pad_left
+            target["boxes"][:, 1::2] += pad_top
+            if "masks" in target:
+                target["masks"] = F.pad(target["masks"], padding, 0, "constant")
+
+        return img, target
+
+    def _crop(self, img, target, top, left, height, width):
+        img = F.crop(img, top, left, height, width)
+        if target is not None:
+            boxes = target["boxes"]
+            boxes[:, 0::2] -= left
+            boxes[:, 1::2] -= top
+            boxes[:, 0::2].clamp_(min=0, max=width)
+            boxes[:, 1::2].clamp_(min=0, max=height)
+
+            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
+
+            target["boxes"] = boxes[is_valid]
+            target["labels"] = target["labels"][is_valid]
+            if "masks" in target:
+                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
+
+        return img, target
+
+    def forward(self, img, target=None):
+        _, height, width = F.get_dimensions(img)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        if new_height != height or new_width != width:
+            offset_height = max(height - self.crop_height, 0)
+            offset_width = max(width - self.crop_width, 0)
+
+            r = torch.rand(1)
+            top = int(offset_height * r)
+            left = int(offset_width * r)
+
+            img, target = self._crop(img, target, top, left, new_height, new_width)
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+        if pad_bottom != 0 or pad_right != 0:
+            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
+
+        return img, target
+
+
+class RandomShortestSize(nn.Module):
+    def __init__(
+        self,
+        min_size: Union[List[int], Tuple[int], int],
+        max_size: int,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
+        self.max_size = max_size
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        _, orig_height, orig_width = F.get_dimensions(image)
+
+        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
+        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
+
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            target["boxes"][:, 0::2] *= new_width / orig_width
+            target["boxes"][:, 1::2] *= new_height / orig_height
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
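
A minimal usage sketch for the two new reference transforms (illustrative only, not part of the commit; it assumes the script is run from references/detection so the module above is importable, and the sample values are made up):

    import torch
    from transforms import FixedSizeCrop, RandomShortestSize  # references/detection/transforms.py

    resize = RandomShortestSize(min_size=[480, 512, 544], max_size=1024)
    crop = FixedSizeCrop(size=(512, 512))

    # Fake sample: CHW uint8 image plus a detection target with XYXY boxes and labels.
    image = torch.randint(0, 256, (3, 600, 800), dtype=torch.uint8)
    target = {"boxes": torch.tensor([[10.0, 20.0, 200.0, 240.0]]), "labels": torch.tensor([1])}

    image, target = resize(image, target)  # shorter side near a sampled min_size, longer side <= max_size
    image, target = crop(image, target)    # random crop to at most 512x512, then pad bottom/right as needed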

test/builtin_dataset_mocks.py

Lines changed: 7 additions & 7 deletions
@@ -19,8 +19,8 @@
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
+from torchvision._utils import sequence_to_str
 from torchvision.prototype.datasets._api import find
-from torchvision.prototype.utils._internal import sequence_to_str
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -1329,20 +1329,20 @@ def cub200(info, root, config):
 
 @register_mock
 def eurosat(info, root, config):
-    data_folder = pathlib.Path(root, "eurosat", "2750")
+    data_folder = root / "2750"
     data_folder.mkdir(parents=True)
 
     num_examples_per_class = 3
-    classes = ("AnnualCrop", "Forest")
-    for cls in classes:
+    categories = ["AnnualCrop", "Forest"]
+    for category in categories:
         create_image_folder(
             root=data_folder,
-            name=cls,
-            file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
+            name=category,
+            file_name_fn=lambda idx: f"{category}_{idx + 1}.jpg",
             num_examples=num_examples_per_class,
         )
     make_zip(root, "EuroSAT.zip", data_folder)
-    return len(classes) * num_examples_per_class
+    return len(categories) * num_examples_per_class
 
 
 @register_mock

test/test_prototype_utils.py renamed to test/test_internal_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import pytest
-from torchvision.prototype.utils._internal import sequence_to_str
+from torchvision._utils import sequence_to_str
 
 
 @pytest.mark.parametrize(

test/test_prototype_builtin_datasets.py

Lines changed: 3 additions & 1 deletion
@@ -10,8 +10,8 @@
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
+from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
-from torchvision.prototype.utils._internal import sequence_to_str
 
 
 assert_samples_equal = functools.partial(
@@ -53,6 +53,8 @@ def test_sample(self, test_home, dataset_mock, config):
 
         try:
             sample = next(iter(dataset))
+        except StopIteration:
+            raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
            raise AssertionError("Drawing a sample raised the error above.") from error

torchvision/_utils.py

Lines changed: 13 additions & 1 deletion
@@ -1,5 +1,5 @@
 import enum
-from typing import TypeVar, Type
+from typing import Sequence, TypeVar, Type
 
 T = TypeVar("T", bound=enum.Enum)
 
@@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
 
 class StrEnum(enum.Enum, metaclass=StrEnumMeta):
     pass
+
+
+def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
+    if len(seq) == 1:
+        return f"'{seq[0]}'"
+
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail
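
For quick reference, the helper's expected output after its move to torchvision._utils (illustrative examples, not part of the diff):

    from torchvision._utils import sequence_to_str

    print(sequence_to_str(["a"]))                                  # 'a'
    print(sequence_to_str(["a", "b"], separate_last="or "))        # 'a' or 'b'
    print(sequence_to_str(["a", "b", "c"], separate_last="and "))  # 'a', 'b', and 'c'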

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@
 from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
 
 from torch.utils.data import IterDataPipe
-from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
+from torchvision._utils import sequence_to_str
+from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion
 
 from .._home import use_sharded_dataset
 from ._internal import BUILTIN_DIR, _make_sharded_datapipe

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         from torchvision.prototype.transforms.functional import convert_bounding_box_format
 
         if isinstance(format, str):
-            format = BoundingBoxFormat[format]
+            format = BoundingBoxFormat.from_str(format.upper())
 
         return BoundingBox.new_like(
             self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
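
Uppercasing the user-supplied string before from_str makes lowercase format names acceptable. A tiny illustration (not part of the commit; it assumes from_str resolves enum members by name, as the snippet in torchvision/_utils.py suggests):

    # Illustrative only: what BoundingBox.to_format("xyxy") now does internally.
    from torchvision.prototype.features import BoundingBoxFormat

    fmt = BoundingBoxFormat.from_str("xyxy".upper())
    assert fmt is BoundingBoxFormat.XYXY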

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     TenCrop,
     BatchMultiCrop,
     RandomHorizontalFlip,
+    RandomZoomOut,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda

torchvision/prototype/transforms/_geometry.py

Lines changed: 85 additions & 0 deletions
@@ -270,3 +270,88 @@ def apply_recursively(obj: Any) -> Any:
             return obj
 
     return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
+
+
+class RandomZoomOut(Transform):
+    def __init__(
+        self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+    ) -> None:
+        super().__init__()
+
+        if fill is None:
+            fill = 0.0
+        self.fill = fill
+
+        self.side_range = side_range
+        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
+            raise ValueError(f"Invalid canvas side range provided {side_range}.")
+
+        self.p = p
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        orig_c, orig_h, orig_w = get_image_dimensions(image)
+
+        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        canvas_width = int(orig_w * r)
+        canvas_height = int(orig_h * r)
+
+        r = torch.rand(2)
+        left = int((canvas_width - orig_w) * r[0])
+        top = int((canvas_height - orig_h) * r[1])
+        right = canvas_width - (left + orig_w)
+        bottom = canvas_height - (top + orig_h)
+        padding = [left, top, right, bottom]
+
+        fill = self.fill
+        if not isinstance(fill, collections.abc.Sequence):
+            fill = [fill] * orig_c
+
+        return dict(padding=padding, fill=fill)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image) or is_simple_tensor(input):
+            # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
+            output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant")
+
+            left, top, right, bottom = params["padding"]
+            fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).to().view(-1, 1, 1)
+
+            if top > 0:
+                output[..., :top, :] = fill
+            if left > 0:
+                output[..., :, :left] = fill
+            if bottom > 0:
+                output[..., -bottom:, :] = fill
+            if right > 0:
+                output[..., :, -right:] = fill
+
+            if isinstance(input, features.Image):
+                output = features.Image.new_like(input, output)
+
+            return output
+        elif isinstance(input, PIL.Image.Image):
+            return F.pad_image_pil(
+                input,
+                params["padding"],
+                fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]),
+                padding_mode="constant",
+            )
+        elif isinstance(input, features.BoundingBox):
+            output = F.pad_bounding_box(input, params["padding"], format=input.format)
+
+            left, top, right, bottom = params["padding"]
+            height, width = input.image_size
+            height += top + bottom
+            width += left + right
+
+            return features.BoundingBox.new_like(input, output, image_size=(height, width))
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) >= self.p:
+            return sample
+
+        return super().forward(sample)
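
A rough usage sketch of the new prototype transform (illustrative only; the prototype API and its sample-dispatch helpers such as query_image are still in flux, so treat the exact call pattern as an assumption):

    # Illustrative only: zoom out a single image by padding it onto a larger canvas.
    import torch
    from torchvision.prototype import features, transforms

    zoom_out = transforms.RandomZoomOut(fill=[104.0, 117.0, 124.0], side_range=(1.0, 4.0), p=1.0)

    image = features.Image(torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8))
    padded = zoom_out(image)  # canvas is 1x-4x larger; the added border uses the per-channel fill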

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
     rotate_image_pil,
     pad_image_tensor,
     pad_image_pil,
+    pad_bounding_box,
     crop_image_tensor,
     crop_image_pil,
     perspective_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 1 deletion
@@ -31,7 +31,7 @@ def horizontal_flip_bounding_box(
     bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
 
     return convert_bounding_box_format(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
     ).view(shape)
 
 
@@ -214,6 +214,26 @@ def rotate_image_pil(
 pad_image_tensor = _FT.pad
 pad_image_pil = _FP.pad
 
+
+def pad_bounding_box(
+    bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
+) -> torch.Tensor:
+    left, _, top, _ = _FT._parse_pad_padding(padding)
+
+    shape = bounding_box.shape
+
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    bounding_box[:, 0::2] += left
+    bounding_box[:, 1::2] += top
+
+    return convert_bounding_box_format(
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(shape)
+
+
 crop_image_tensor = _FT.crop
 crop_image_pil = _FP.crop
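
A small illustration of the new functional (not part of the commit): padding on the left/top shifts box coordinates, because the origin of the padded canvas moves.

    # Illustrative only: pad by [left=5, top=7, right=0, bottom=0] and shift XYXY boxes accordingly.
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms.functional import pad_bounding_box

    boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
    padded = pad_bounding_box(boxes, padding=[5, 7, 0, 0], format=features.BoundingBoxFormat.XYXY)
    # padded == tensor([[15., 27., 55., 87.]])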
