Skip to content

Commit 36ae12e

Browse files
authored
Copy paste prototype.datapoints and prototype.transforms out of prototype area (#7259)
1 parent d805aea commit 36ae12e

33 files changed

+8531
-0
lines changed

torchvision/datapoints/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from ._bounding_box import BoundingBox, BoundingBoxFormat
2+
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
3+
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
4+
from ._label import Label, OneHotLabel
5+
from ._mask import Mask
6+
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
7+
8+
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, List, Optional, Sequence, Tuple, Union
4+
5+
import torch
6+
from torchvision._utils import StrEnum
7+
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
8+
9+
from ._datapoint import Datapoint, FillTypeJIT
10+
11+
12+
class BoundingBoxFormat(StrEnum):
13+
XYXY = StrEnum.auto()
14+
XYWH = StrEnum.auto()
15+
CXCYWH = StrEnum.auto()
16+
17+
18+
class BoundingBox(Datapoint):
19+
format: BoundingBoxFormat
20+
spatial_size: Tuple[int, int]
21+
22+
@classmethod
23+
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
24+
bounding_box = tensor.as_subclass(cls)
25+
bounding_box.format = format
26+
bounding_box.spatial_size = spatial_size
27+
return bounding_box
28+
29+
def __new__(
30+
cls,
31+
data: Any,
32+
*,
33+
format: Union[BoundingBoxFormat, str],
34+
spatial_size: Tuple[int, int],
35+
dtype: Optional[torch.dtype] = None,
36+
device: Optional[Union[torch.device, str, int]] = None,
37+
requires_grad: Optional[bool] = None,
38+
) -> BoundingBox:
39+
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
40+
41+
if isinstance(format, str):
42+
format = BoundingBoxFormat.from_str(format.upper())
43+
44+
return cls._wrap(tensor, format=format, spatial_size=spatial_size)
45+
46+
@classmethod
47+
def wrap_like(
48+
cls,
49+
other: BoundingBox,
50+
tensor: torch.Tensor,
51+
*,
52+
format: Optional[BoundingBoxFormat] = None,
53+
spatial_size: Optional[Tuple[int, int]] = None,
54+
) -> BoundingBox:
55+
return cls._wrap(
56+
tensor,
57+
format=format if format is not None else other.format,
58+
spatial_size=spatial_size if spatial_size is not None else other.spatial_size,
59+
)
60+
61+
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
62+
return self._make_repr(format=self.format, spatial_size=self.spatial_size)
63+
64+
def horizontal_flip(self) -> BoundingBox:
65+
output = self._F.horizontal_flip_bounding_box(
66+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
67+
)
68+
return BoundingBox.wrap_like(self, output)
69+
70+
def vertical_flip(self) -> BoundingBox:
71+
output = self._F.vertical_flip_bounding_box(
72+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
73+
)
74+
return BoundingBox.wrap_like(self, output)
75+
76+
def resize( # type: ignore[override]
77+
self,
78+
size: List[int],
79+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
80+
max_size: Optional[int] = None,
81+
antialias: Optional[Union[str, bool]] = "warn",
82+
) -> BoundingBox:
83+
output, spatial_size = self._F.resize_bounding_box(
84+
self.as_subclass(torch.Tensor),
85+
spatial_size=self.spatial_size,
86+
size=size,
87+
max_size=max_size,
88+
)
89+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
90+
91+
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
92+
output, spatial_size = self._F.crop_bounding_box(
93+
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
94+
)
95+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
96+
97+
def center_crop(self, output_size: List[int]) -> BoundingBox:
98+
output, spatial_size = self._F.center_crop_bounding_box(
99+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
100+
)
101+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
102+
103+
def resized_crop(
104+
self,
105+
top: int,
106+
left: int,
107+
height: int,
108+
width: int,
109+
size: List[int],
110+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
111+
antialias: Optional[Union[str, bool]] = "warn",
112+
) -> BoundingBox:
113+
output, spatial_size = self._F.resized_crop_bounding_box(
114+
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
115+
)
116+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
117+
118+
def pad(
119+
self,
120+
padding: Union[int, Sequence[int]],
121+
fill: Optional[Union[int, float, List[float]]] = None,
122+
padding_mode: str = "constant",
123+
) -> BoundingBox:
124+
output, spatial_size = self._F.pad_bounding_box(
125+
self.as_subclass(torch.Tensor),
126+
format=self.format,
127+
spatial_size=self.spatial_size,
128+
padding=padding,
129+
padding_mode=padding_mode,
130+
)
131+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
132+
133+
def rotate(
134+
self,
135+
angle: float,
136+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
137+
expand: bool = False,
138+
center: Optional[List[float]] = None,
139+
fill: FillTypeJIT = None,
140+
) -> BoundingBox:
141+
output, spatial_size = self._F.rotate_bounding_box(
142+
self.as_subclass(torch.Tensor),
143+
format=self.format,
144+
spatial_size=self.spatial_size,
145+
angle=angle,
146+
expand=expand,
147+
center=center,
148+
)
149+
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
150+
151+
def affine(
152+
self,
153+
angle: Union[int, float],
154+
translate: List[float],
155+
scale: float,
156+
shear: List[float],
157+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
158+
fill: FillTypeJIT = None,
159+
center: Optional[List[float]] = None,
160+
) -> BoundingBox:
161+
output = self._F.affine_bounding_box(
162+
self.as_subclass(torch.Tensor),
163+
self.format,
164+
self.spatial_size,
165+
angle,
166+
translate=translate,
167+
scale=scale,
168+
shear=shear,
169+
center=center,
170+
)
171+
return BoundingBox.wrap_like(self, output)
172+
173+
def perspective(
174+
self,
175+
startpoints: Optional[List[List[int]]],
176+
endpoints: Optional[List[List[int]]],
177+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
178+
fill: FillTypeJIT = None,
179+
coefficients: Optional[List[float]] = None,
180+
) -> BoundingBox:
181+
output = self._F.perspective_bounding_box(
182+
self.as_subclass(torch.Tensor),
183+
format=self.format,
184+
spatial_size=self.spatial_size,
185+
startpoints=startpoints,
186+
endpoints=endpoints,
187+
coefficients=coefficients,
188+
)
189+
return BoundingBox.wrap_like(self, output)
190+
191+
def elastic(
192+
self,
193+
displacement: torch.Tensor,
194+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
195+
fill: FillTypeJIT = None,
196+
) -> BoundingBox:
197+
output = self._F.elastic_bounding_box(
198+
self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
199+
)
200+
return BoundingBox.wrap_like(self, output)

0 commit comments

Comments
 (0)