
Commit 0e7ab27

Merge branch 'main' into raft-reference-improvement

2 parents: 2857e21 + 7be2f55

21 files changed, +393 -67 lines

README.rst

Lines changed: 3 additions & 1 deletion

@@ -21,7 +21,9 @@ supported Python versions.
 +--------------------------+--------------------------+---------------------------------+
 | ``torch``                | ``torchvision``          | ``python``                      |
 +==========================+==========================+=================================+
-| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.9``            |
+| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.10``           |
++--------------------------+--------------------------+---------------------------------+
+| ``1.11.0``               | ``0.12.0``               | ``>=3.7``, ``<=3.10``           |
 +--------------------------+--------------------------+---------------------------------+
 | ``1.10.2``               | ``0.11.3``               | ``>=3.6``, ``<=3.9``            |
 +--------------------------+--------------------------+---------------------------------+
gallery/plot_visualization_utils.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@
 def show(imgs):
     if not isinstance(imgs, list):
         imgs = [imgs]
-    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
     for i, img in enumerate(imgs):
         img = img.detach()
         img = F.to_pil_image(img)
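
Note: the rename fixes a misleading variable name; the figure handle was unused either way. For context, a minimal self-contained sketch of the fixed helper. The plotting calls after the subplots line are assumed from the rest of the gallery script and are not part of this diff:

import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

def show(imgs):
    # Accept a single image tensor or a list of image tensors.
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        # Assumed continuation: draw each image and hide the axis ticks.
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

show(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))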

references/detection/transforms.py

Lines changed: 112 additions & 1 deletion

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Union

 import torch
 import torchvision
@@ -326,3 +326,114 @@ def forward(
         )

         return image, target
+
+
+class FixedSizeCrop(nn.Module):
+    def __init__(self, size, fill=0, padding_mode="constant"):
+        super().__init__()
+        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _pad(self, img, target, padding):
+        # Taken from the functional_tensor.py pad
+        if isinstance(padding, int):
+            pad_left = pad_right = pad_top = pad_bottom = padding
+        elif len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        else:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+
+        padding = [pad_left, pad_top, pad_right, pad_bottom]
+        img = F.pad(img, padding, self.fill, self.padding_mode)
+        if target is not None:
+            target["boxes"][:, 0::2] += pad_left
+            target["boxes"][:, 1::2] += pad_top
+            if "masks" in target:
+                target["masks"] = F.pad(target["masks"], padding, 0, "constant")
+
+        return img, target
+
+    def _crop(self, img, target, top, left, height, width):
+        img = F.crop(img, top, left, height, width)
+        if target is not None:
+            boxes = target["boxes"]
+            boxes[:, 0::2] -= left
+            boxes[:, 1::2] -= top
+            boxes[:, 0::2].clamp_(min=0, max=width)
+            boxes[:, 1::2].clamp_(min=0, max=height)
+
+            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
+
+            target["boxes"] = boxes[is_valid]
+            target["labels"] = target["labels"][is_valid]
+            if "masks" in target:
+                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
+
+        return img, target
+
+    def forward(self, img, target=None):
+        _, height, width = F.get_dimensions(img)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        if new_height != height or new_width != width:
+            offset_height = max(height - self.crop_height, 0)
+            offset_width = max(width - self.crop_width, 0)
+
+            r = torch.rand(1)
+            top = int(offset_height * r)
+            left = int(offset_width * r)
+
+            img, target = self._crop(img, target, top, left, new_height, new_width)
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+        if pad_bottom != 0 or pad_right != 0:
+            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
+
+        return img, target
+
+
+class RandomShortestSize(nn.Module):
+    def __init__(
+        self,
+        min_size: Union[List[int], Tuple[int], int],
+        max_size: int,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
+        self.max_size = max_size
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        _, orig_height, orig_width = F.get_dimensions(image)
+
+        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
+        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
+
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            target["boxes"][:, 0::2] *= new_width / orig_width
+            target["boxes"][:, 1::2] *= new_height / orig_height
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
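
A hedged usage sketch of the two new transforms. The image, target values, and size choices below are invented for illustration; boxes are assumed XYXY floats, as the box arithmetic above implies:

import torch

# Hypothetical pipeline built from the two classes above; in the reference
# scripts they live in references/detection/transforms.py.
resize = RandomShortestSize(min_size=[480, 512], max_size=800)
crop = FixedSizeCrop(size=(512, 512), fill=0)

image = torch.randint(0, 256, (3, 600, 900), dtype=torch.uint8)
target = {
    "boxes": torch.tensor([[10.0, 20.0, 200.0, 300.0]]),  # XYXY, float
    "labels": torch.tensor([1]),
}

# Resize so the shorter edge is near min_size (capped by max_size), then
# crop/pad to a fixed 512x512 canvas; boxes are shifted, clamped, and
# degenerate ones filtered out along the way.
image, target = resize(image, target)
image, target = crop(image, target)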

test/builtin_dataset_mocks.py

Lines changed: 7 additions & 7 deletions

@@ -19,8 +19,8 @@
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
+from torchvision._utils import sequence_to_str
 from torchvision.prototype.datasets._api import find
-from torchvision.prototype.utils._internal import sequence_to_str

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -1329,20 +1329,20 @@ def cub200(info, root, config):

 @register_mock
 def eurosat(info, root, config):
-    data_folder = pathlib.Path(root, "eurosat", "2750")
+    data_folder = root / "2750"
     data_folder.mkdir(parents=True)

     num_examples_per_class = 3
-    classes = ("AnnualCrop", "Forest")
-    for cls in classes:
+    categories = ["AnnualCrop", "Forest"]
+    for category in categories:
         create_image_folder(
             root=data_folder,
-            name=cls,
-            file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
+            name=category,
+            file_name_fn=lambda idx: f"{category}_{idx + 1}.jpg",
             num_examples=num_examples_per_class,
         )
     make_zip(root, "EuroSAT.zip", data_folder)
-    return len(classes) * num_examples_per_class
+    return len(categories) * num_examples_per_class


 @register_mock
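
For reference, the on-disk layout the reworked mock creates, reconstructed from the code above (idx runs from 0 to num_examples - 1, so the + 1 makes file names 1-based, presumably to match the real EuroSAT archive):

from pathlib import Path

# Two category folders under root/2750, three 1-indexed JPEGs each,
# then zipped into root/EuroSAT.zip.
root = Path("root")
expected_files = [
    root / "2750" / category / f"{category}_{idx + 1}.jpg"
    for category in ["AnnualCrop", "Forest"]
    for idx in range(3)
]
assert len(expected_files) == 6  # the sample count the mock returns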

test/common_utils.py

Lines changed: 2 additions & 2 deletions

@@ -137,7 +137,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
     return batch_tensor


-assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=1e-6)
+assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
@@ -195,7 +195,7 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
     for i in range(len(batch_tensors)):
         img_tensor = batch_tensors[i, ...]
         transformed_img = fn(img_tensor, **fn_kwargs)
-        assert_equal(transformed_img, transformed_batch[i, ...])
+        torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)

     if scripted_fn_atol >= 0:
         scripted_fn = torch.jit.script(fn)
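
A quick sketch of the tolerance semantics behind this change: assert_equal is now exact, and the one batch comparison that genuinely needs slack requests atol=1e-6 explicitly. Sample values below are illustrative:

import torch

a = torch.tensor([1.0])
b = torch.tensor([1.0 + 1e-7])  # differs by about 1.19e-7 in float32

torch.testing.assert_close(a, b, rtol=0, atol=1e-6)  # passes: within 1e-6
# torch.testing.assert_close(a, b, rtol=0, atol=0)   # would fail: exact match required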

test/test_prototype_utils.py renamed to test/test_internal_utils.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import pytest
-from torchvision.prototype.utils._internal import sequence_to_str
+from torchvision._utils import sequence_to_str


 @pytest.mark.parametrize(

test/test_prototype_builtin_datasets.py

Lines changed: 3 additions & 1 deletion

@@ -10,8 +10,8 @@
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
+from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
-from torchvision.prototype.utils._internal import sequence_to_str


 assert_samples_equal = functools.partial(
@@ -53,6 +53,8 @@ def test_sample(self, test_home, dataset_mock, config):

        try:
            sample = next(iter(dataset))
+       except StopIteration:
+           raise AssertionError("Unable to draw any sample.") from None
        except Exception as error:
            raise AssertionError("Drawing a sample raised the error above.") from error
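
The new except StopIteration branch matters because StopIteration is a subclass of Exception, so without it an empty datapipe would fall into the generic handler with a confusing chained traceback. A minimal standalone illustration (the helper name is hypothetical):

def first_sample(iterable):
    try:
        return next(iter(iterable))
    except StopIteration:
        # An empty pipeline is its own failure mode, not "an error above".
        raise AssertionError("Unable to draw any sample.") from None
    except Exception as error:
        raise AssertionError("Drawing a sample raised the error above.") from error

assert first_sample([42]) == 42
# first_sample([])  # AssertionError: Unable to draw any sample.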

test/test_prototype_transforms.py

Lines changed: 56 additions & 2 deletions

@@ -2,9 +2,10 @@

 import pytest
 import torch
+from common_utils import assert_equal
 from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor


 def make_vanilla_tensor_images(*args, **kwargs):
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms):
 class TestSmoke:
     @parametrize_from_transforms(
         transforms.RandomErasing(p=1.0),
-        transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
         transforms.ConvertImageDtype(),
+        transforms.RandomHorizontalFlip(),
     )
     def test_common(self, transform, input):
         transform(input)
@@ -188,3 +189,56 @@ def test_random_resized_crop(self, transform, input):
     )
     def test_convert_image_color_space(self, transform, input):
         transform(input)
+
+
+@pytest.mark.parametrize("p", [0.0, 1.0])
+class TestRandomHorizontalFlip:
+    def input_expected_image_tensor(self, p, dtype=torch.float32):
+        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
+        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
+
+        return input, expected if p == 1 else input
+
+    def test_simple_tensor(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        assert_equal(expected, actual)
+
+    def test_pil_image(self, p):
+        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(to_pil_image(input))
+
+        assert_equal(expected, pil_to_tensor(actual))
+
+    def test_features_image(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.Image(input))
+
+        assert_equal(features.Image(expected), actual)
+
+    def test_features_segmentation_mask(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.SegmentationMask(input))
+
+        assert_equal(features.SegmentationMask(expected), actual)
+
+    def test_features_bounding_box(self, p):
+        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
+        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        assert_equal(expected, actual)
+        assert actual.format == expected.format
+        assert actual.image_size == expected.image_size
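
The expected box in test_features_bounding_box follows from the XYXY flip arithmetic: on an image of width W, a horizontal flip maps x1 to W - x2 and x2 to W - x1, leaving the y coordinates untouched:

import torch

W = 10                       # image width from image_size=(10, 10)
x1, y1, x2, y2 = 0, 0, 5, 5  # the input box
flipped = torch.tensor([W - x2, y1, W - x1, y2])
assert flipped.tolist() == [5, 0, 10, 5]  # the expected tensor in the test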

test/test_transforms.py

Lines changed: 3 additions & 3 deletions

@@ -452,12 +452,12 @@ def test_resize_size_equals_small_edge_size(height, width):


 class TestPad:
-    def test_pad(self):
+    @pytest.mark.parametrize("fill", [85, 85.0])
+    def test_pad(self, fill):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
         img = torch.ones(3, height, width, dtype=torch.uint8)
         padding = random.randint(1, 20)
-        fill = random.randint(1, 50)
         result = transforms.Compose(
             [
                 transforms.ToPILImage(),
@@ -484,7 +484,7 @@ def test_pad_with_tuple_of_pad_values(self):
         output = transforms.Pad(padding)(img)
         assert output.size == (width + padding[0] * 2, height + padding[1] * 2)

-        padding = tuple(random.randint(1, 20) for _ in range(4))
+        padding = [random.randint(1, 20) for _ in range(4)]
         output = transforms.Pad(padding)(img)
         assert output.size[0] == width + padding[0] + padding[2]
         assert output.size[1] == height + padding[1] + padding[3]
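
The parametrization now checks that an int fill and the equivalent float fill behave identically. A hedged standalone version of the same check, assuming the float-fill support on PIL inputs that this test exercises:

import torch
import torchvision.transforms as transforms

img = transforms.ToPILImage()(torch.ones(3, 16, 16, dtype=torch.uint8))
out_int = transforms.Pad(2, fill=85)(img)
out_float = transforms.Pad(2, fill=85.0)(img)
# A padded corner pixel should hold the fill value in both cases.
assert out_int.getpixel((0, 0)) == out_float.getpixel((0, 0)) == (85, 85, 85)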

torchvision/_utils.py

Lines changed: 13 additions & 1 deletion

@@ -1,5 +1,5 @@
 import enum
-from typing import TypeVar, Type
+from typing import Sequence, TypeVar, Type

 T = TypeVar("T", bound=enum.Enum)

@@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]

 class StrEnum(enum.Enum, metaclass=StrEnumMeta):
     pass
+
+
+def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
+    if len(seq) == 1:
+        return f"'{seq[0]}'"
+
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail
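
The behavior of the relocated helper, traced directly from the implementation above:

from torchvision._utils import sequence_to_str

sequence_to_str([])                                           # ''
sequence_to_str(["foo"])                                      # "'foo'"
sequence_to_str(["foo", "bar"])                               # "'foo', 'bar'"
sequence_to_str(["foo", "bar"], separate_last="or ")          # "'foo' or 'bar'"
sequence_to_str(["foo", "bar", "baz"], separate_last="and ")  # "'foo', 'bar', and 'baz'"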

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 2 additions & 1 deletion

@@ -7,7 +7,8 @@
 from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection

 from torch.utils.data import IterDataPipe
-from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
+from torchvision._utils import sequence_to_str
+from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion

 from .._home import use_sharded_dataset
 from ._internal import BUILTIN_DIR, _make_sharded_datapipe

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         from torchvision.prototype.transforms.functional import convert_bounding_box_format

         if isinstance(format, str):
-            format = BoundingBoxFormat[format]
+            format = BoundingBoxFormat.from_str(format.upper())

         return BoundingBox.new_like(
             self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
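
The motivation, sketched: plain enum indexing is case-sensitive, so lowercase format strings used to raise KeyError, while the new path normalizes the case first. A small illustration, assuming the prototype BoundingBoxFormat enum with the XYXY member and the from_str classmethod shown in torchvision/_utils.py above:

from torchvision.prototype.features import BoundingBoxFormat

fmt = BoundingBoxFormat["XYXY"]                   # exact-case indexing works
fmt = BoundingBoxFormat.from_str("xyxy".upper())  # the new, case-tolerant path
# BoundingBoxFormat["xyxy"]                       # KeyError: enum indexing is case-sensitive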

torchvision/prototype/transforms/__init__.py

Lines changed: 10 additions & 1 deletion

@@ -7,7 +7,16 @@
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
+from ._geometry import (
+    Resize,
+    CenterCrop,
+    RandomResizedCrop,
+    FiveCrop,
+    TenCrop,
+    BatchMultiCrop,
+    RandomHorizontalFlip,
+    RandomZoomOut,
+)
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import (
