
Commit 7cffef6

Merge branch 'revamp-prototype-features-transforms' of https://github.com/pytorch/vision into revamp-prototype-features-transforms
2 parents fad04f4 + 466b845

12 files changed: +495 −1 lines changed
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
+import functools
+import itertools
+
+import pytest
+import torch.testing
+import torchvision.prototype.transforms.functional as F
+from torch import jit
+from torchvision.prototype import features
+
+make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
+
+
+def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
+    size = size or torch.randint(16, 33, (2,)).tolist()
+
+    if isinstance(color_space, str):
+        color_space = features.ColorSpace[color_space]
+    num_channels = {
+        features.ColorSpace.GRAYSCALE: 1,
+        features.ColorSpace.RGB: 3,
+    }[color_space]
+
+    shape = (*extra_dims, num_channels, *size)
+    if dtype.is_floating_point:
+        data = torch.rand(shape, dtype=dtype)
+    else:
+        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
+    return features.Image(data, color_space=color_space)
+
+
+make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
+make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
+
+
+def make_images(
+    sizes=((16, 16), (7, 33), (31, 9)),
+    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
+    dtypes=(torch.float32, torch.uint8),
+    extra_dims=((4,), (2, 3)),
+):
+    for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
+        yield make_image(size, color_space=color_space, dtype=dtype)
+
+    for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
+        yield make_image(color_space=color_space, extra_dims=extra_dims_)
+
+
+def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
+    low, high = torch.broadcast_tensors(
+        *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
+    )
+    return torch.stack(
+        [
+            torch.randint(low_scalar, high_scalar, (), **kwargs)
+            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
+        ]
+    ).reshape(low.shape)
+
+
+def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
+    if isinstance(format, str):
+        format = features.BoundingBoxFormat[format]
+
+    height, width = image_size
+
+    if format == features.BoundingBoxFormat.XYXY:
+        x1 = torch.randint(0, width // 2, extra_dims)
+        y1 = torch.randint(0, height // 2, extra_dims)
+        x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
+        y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
+        parts = (x1, y1, x2, y2)
+    elif format == features.BoundingBoxFormat.XYWH:
+        x = torch.randint(0, width // 2, extra_dims)
+        y = torch.randint(0, height // 2, extra_dims)
+        w = randint_with_tensor_bounds(1, width - x)
+        h = randint_with_tensor_bounds(1, height - y)
+        parts = (x, y, w, h)
+    elif format == features.BoundingBoxFormat.CXCYWH:
+        cx = torch.randint(1, width - 1, extra_dims)
+        cy = torch.randint(1, height - 1, extra_dims)
+        w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
+        parts = (cx, cy, w, h)
+    else:
+        raise ValueError(f"Format {format} is not supported")
+
+    return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
+
+
+make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
+
+
+def make_bounding_boxes(
+    formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
+    image_sizes=((32, 32),),
+    dtypes=(torch.int64, torch.float32),
+    extra_dims=((4,), (2, 3)),
+):
+    for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
+        yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
+
+    for format, extra_dims_ in itertools.product(formats, extra_dims):
+        yield make_bounding_box(format=format, extra_dims=extra_dims_)
+
+
+class SampleInput:
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+
+
+class KernelInfo:
+    def __init__(self, name, *, sample_inputs_fn):
+        self.name = name
+        self.kernel = getattr(F, name)
+        self._sample_inputs_fn = sample_inputs_fn
+
+    def sample_inputs(self):
+        yield from self._sample_inputs_fn()
+
+    def __call__(self, *args, **kwargs):
+        if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
+            sample_input = args[0]
+            return self.kernel(*sample_input.args, **sample_input.kwargs)
+
+        return self.kernel(*args, **kwargs)
+
+
+KERNEL_INFOS = []
+
+
+def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
+    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
+    return sample_inputs_fn
+
+
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_image():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def resize_image():
+    for image, interpolation in itertools.product(
+        make_images(),
+        [
+            F.InterpolationMode.BILINEAR,
+            F.InterpolationMode.NEAREST,
+        ],
+    ):
+        height, width = image.shape[-2:]
+        for size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            yield SampleInput(image, size=size, interpolation=interpolation)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def resize_bounding_box():
+    for bounding_box in make_bounding_boxes():
+        height, width = bounding_box.image_size
+        for new_image_size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
+
+
+class TestKernelsCommon:
+    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
+    def test_scriptable(self, kernel_info):
+        jit.script(kernel_info.kernel)
+
+    @pytest.mark.parametrize(
+        ("kernel_info", "sample_input"),
+        [
+            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
+            for kernel_info in KERNEL_INFOS
+            for idx, sample_input in enumerate(kernel_info.sample_inputs())
+        ],
+    )
+    def test_eager_vs_scripted(self, kernel_info, sample_input):
+        eager = kernel_info(sample_input)
+        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
+
+        torch.testing.assert_close(eager, scripted)
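
Note: the registry above keys each sample-inputs generator to a kernel by name, since KernelInfo resolves the kernel with getattr(F, name); the generator's name must therefore match a function exported from torchvision.prototype.transforms.functional. A minimal sketch of how one entry is consumed (illustrative only, not part of the diff):

info = KERNEL_INFOS[0]                     # e.g. the horizontal_flip_image entry
sample_input = next(info.sample_inputs())  # a SampleInput wrapping a features.Image
output = info(sample_input)                # unpacked and dispatched to F.horizontal_flip_image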

torchvision/prototype/features/_bounding_box.py

Lines changed: 11 additions & 0 deletions
@@ -36,3 +36,14 @@ def __new__(
         bounding_box._metadata.update(dict(format=format, image_size=image_size))
 
         return bounding_box
+
+    def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
+        # import at runtime to avoid cyclic imports
+        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+
+        if isinstance(format, str):
+            format = BoundingBoxFormat[format]
+
+        return BoundingBox.new_like(
+            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
+        )
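
Note: a minimal usage sketch for to_format; that BoundingBox accepts list-like data here is an assumption for illustration:

from torchvision.prototype import features

box = features.BoundingBox([10, 10, 20, 30], format="XYWH", image_size=(32, 32))
box_xyxy = box.to_format("XYXY")  # accepts a string or a BoundingBoxFormat member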

torchvision/prototype/features/_encoded.py

Lines changed: 7 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
 
 from ._feature import Feature
+from ._image import Image
 
 D = TypeVar("D", bound="EncodedData")
 
@@ -37,6 +38,12 @@ def image_size(self) -> Tuple[int, int]:
 
         return self._image_size
 
+    def decode(self) -> Image:
+        # import at runtime to avoid cyclic imports
+        from torchvision.prototype.transforms.functional import decode_image_with_pil
+
+        return Image(decode_image_with_pil(self))
+
 
 class EncodedVideo(EncodedData):
     pass
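
Note: a hedged sketch of the new decode helper; constructing an EncodedImage directly from raw bytes is an assumption for illustration, and the file path is a placeholder:

import torch
from torchvision.prototype import features

raw = torch.frombuffer(bytearray(open("test.jpg", "rb").read()), dtype=torch.uint8)
image = features.EncodedImage(raw).decode()  # a features.Image decoded via decode_image_with_pil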

torchvision/prototype/features/_image.py

Lines changed: 5 additions & 0 deletions
@@ -4,8 +4,10 @@
 import torch
 from torchvision.prototype.utils._internal import StrEnum
 from torchvision.transforms.functional import to_pil_image
+from torchvision.utils import draw_bounding_boxes
 from torchvision.utils import make_grid
 
+from ._bounding_box import BoundingBox
 from ._feature import Feature
 
 
@@ -76,3 +78,6 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
 
     def show(self) -> None:
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
+
+    def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image":
+        return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
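
Note: a minimal sketch tying the two features together; torchvision.utils.draw_bounding_boxes operates on uint8 images, so one is assumed here, and extra keyword arguments are forwarded to it:

import torch
from torchvision.prototype import features

image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
box = features.BoundingBox([2, 2, 10, 10], format="XYXY", image_size=(32, 32))
annotated = image.draw_bounding_box(box)  # returns a new Image; the original is untouched
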
Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,8 @@
+from . import functional
+from .functional import InterpolationMode  # usort: skip
+
 from ._transform import Transform
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder  # usort: skip
-
 from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
 from ._misc import Identity, Normalize
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
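
Note: with this change the functional namespace ships alongside the transform classes; a quick sketch of the resulting import surface:

from torchvision.prototype import transforms
from torchvision.prototype.transforms import functional as F

mode = transforms.InterpolationMode.BILINEAR  # re-exported from .functional
kernel = F.horizontal_flip_image              # low-level kernels live on the functional module
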
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
+from ._color import (
+    adjust_brightness_image,
+    adjust_contrast_image,
+    adjust_saturation_image,
+    adjust_sharpness_image,
+    posterize_image,
+    solarize_image,
+    autocontrast_image,
+    equalize_image,
+    invert_image,
+)
+from ._geometry import (
+    horizontal_flip_bounding_box,
+    horizontal_flip_image,
+    resize_bounding_box,
+    resize_image,
+    resize_segmentation_mask,
+    center_crop_image,
+    resized_crop_image,
+    InterpolationMode,
+    affine_image,
+    rotate_image,
+)
+from ._meta_conversion import convert_color_space, convert_bounding_box_format
+from ._misc import normalize_image
+from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from typing import Tuple
+
+import torch
+from torchvision.transforms import functional as _F
+
+
+erase_image = _F.erase
+
+
+def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
+    if not inplace:
+        input = input.clone()
+
+    input_rolled = input.roll(1, batch_dim)
+    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))
+
+
+def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
+    return _mixup(image_batch, -4, lam, inplace)
+
+
+def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
+    return _mixup(one_hot_label_batch, -2, lam, inplace)
+
+
+def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
+    if not inplace:
+        image = image.clone()
+
+    x1, y1, x2, y2 = box
+    image_rolled = image.roll(1, -4)
+
+    image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
+    return image
+
+
+def cutmix_one_hot_label(
+    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
+) -> torch.Tensor:
+    return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace)
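
Note: a minimal sketch of the batch conventions these kernels assume; mixup_image rolls dim -4 (the batch dimension of an NCHW batch) and mixup_one_hot_label rolls dim -2:

import torch
from torchvision.prototype.transforms import functional as F

image_batch = torch.rand(8, 3, 16, 16)                   # (N, C, H, W)
one_hot_batch = torch.eye(10)[torch.randint(10, (8,))]   # (N, num_classes)

mixed_images = F.mixup_image(image_batch, lam=0.7)
mixed_labels = F.mixup_one_hot_label(one_hot_batch, lam=0.7)
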
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+from torchvision.transforms import functional as _F
+
+
+adjust_brightness_image = _F.adjust_brightness
+
+adjust_saturation_image = _F.adjust_saturation
+
+adjust_contrast_image = _F.adjust_contrast
+
+adjust_sharpness_image = _F.adjust_sharpness
+
+posterize_image = _F.posterize
+
+solarize_image = _F.solarize
+
+autocontrast_image = _F.autocontrast
+
+equalize_image = _F.equalize
+
+invert_image = _F.invert
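
Note: these are thin aliases; each *_image name re-points to the corresponding stable kernel in torchvision.transforms.functional, so behavior is unchanged. A quick sketch (posterize operates on uint8 images):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
out = F.posterize_image(img, bits=4)  # same as torchvision.transforms.functional.posterize(img, 4)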
