
replace new_like with wrap_like #6718

Merged: 6 commits, Oct 7, 2022
4 changes: 2 additions & 2 deletions test/test_prototype_features.py
@@ -99,14 +99,14 @@ def test_inplace_op_no_wrapping():
     assert type(label) is features.Label


-def test_new_like():
+def test_wrap_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = features.Label(tensor, categories=["foo", "bar"])

     # any operation besides .to() and .clone() will do here
     output = label * 2

-    label_new = features.Label.new_like(label, output)
+    label_new = features.Label.wrap_like(label, output)

     assert type(label_new) is features.Label
     assert label_new.data_ptr() == output.data_ptr()
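The `data_ptr` check is the heart of the renamed test: `wrap_like` re-wraps the existing output tensor instead of constructing a new feature from it. A minimal sketch of that contract, mirroring the test above and assuming the same prototype `Label` API:

```python
import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([0, 1, 0], dtype=torch.int64), categories=["foo", "bar"])
output = label * 2  # any op besides .to() / .clone() drops the Label subclass

# previously: features.Label.new_like(label, output), which went through the constructor again
wrapped = features.Label.wrap_like(label, output)

assert type(wrapped) is features.Label
assert wrapped.data_ptr() == output.data_ptr()  # same storage, nothing was copied
```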
17 changes: 9 additions & 8 deletions test/test_prototype_transforms.py
@@ -8,6 +8,7 @@
 import torch
 from common_utils import assert_equal, cpu_and_gpu
 from prototype_common_utils import (
+    DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
     make_detection_mask,
@@ -23,6 +24,8 @@
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

+BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
+

 def make_vanilla_tensor_images(*args, **kwargs):
     for image in make_images(*args, **kwargs):
@@ -109,13 +112,11 @@ def test_common(self, transform, input):
         (
             transform,
             [
-                dict(
-                    image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
-                    one_hot_label=features.OneHotLabel.new_like(
-                        one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
-                    ),
-                )
-                for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
+                dict(image=image, one_hot_label=one_hot_label)
+                for image, one_hot_label in itertools.product(
+                    make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                    make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                )
             ],
         )
         for transform in [
@@ -300,7 +301,7 @@ def test_features_bounding_box(self, p):
         actual = transform(input)

         expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
@@ -353,7 +354,7 @@ def test_features_bounding_box(self, p):
         actual = transform(input)

         expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
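For context on the new module-level constant: the `if extra_dims` filter only makes sense if `DEFAULT_EXTRA_DIMS` (imported from `prototype_common_utils`) also contains an empty, unbatched entry, so the comprehension keeps just the batched shapes. A tiny illustration with a made-up value; the real constant presumably lives in `test/prototype_common_utils.py`:

```python
# Hypothetical value, for illustration only.
DEFAULT_EXTRA_DIMS = ((), (4,), (2, 3))

# Keep only genuinely batched extra dims; the empty tuple (single unbatched sample) is falsy.
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
assert BATCH_EXTRA_DIMS == [(4,), (2, 3)]
```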
57 changes: 29 additions & 28 deletions torchvision/prototype/features/_bounding_box.py
@@ -19,6 +19,13 @@ class BoundingBox(_Feature):
     format: BoundingBoxFormat
     image_size: Tuple[int, int]

+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox:
+        bounding_box = tensor.as_subclass(cls)
+        bounding_box.format = format
+        bounding_box.image_size = image_size
+        return bounding_box
+
     def __new__(
         cls,
         data: Any,
@@ -29,52 +36,46 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> BoundingBox:
-        bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
-        bounding_box.format = format
-
-        bounding_box.image_size = image_size
-
-        return bounding_box
-
-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(format=self.format, image_size=self.image_size)
+
+        return cls._wrap(tensor, format=format, image_size=image_size)

     @classmethod
-    def new_like(
+    def wrap_like(
         cls,
         other: BoundingBox,
-        data: Any,
+        tensor: torch.Tensor,
         *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
+        format: Optional[BoundingBoxFormat] = None,
         image_size: Optional[Tuple[int, int]] = None,
-        **kwargs: Any,
     ) -> BoundingBox:
-        return super().new_like(
-            other,
-            data,
+        return cls._wrap(
+            tensor,
             format=format if format is not None else other.format,
             image_size=image_size if image_size is not None else other.image_size,
-            **kwargs,
         )

+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(format=self.format, image_size=self.image_size)
+
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())

-        return BoundingBox.new_like(
+        return BoundingBox.wrap_like(
             self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
         )

     def horizontal_flip(self) -> BoundingBox:
         output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def vertical_flip(self) -> BoundingBox:
         output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -84,19 +85,19 @@ def resize(  # type: ignore[override]
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
         output, image_size = self._F.crop_bounding_box(
             self, self.format, top=top, left=left, height=height, width=width
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def center_crop(self, output_size: List[int]) -> BoundingBox:
         output, image_size = self._F.center_crop_bounding_box(
             self, format=self.format, image_size=self.image_size, output_size=output_size
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def resized_crop(
         self,
@@ -109,7 +110,7 @@ def resized_crop(
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def pad(
         self,
@@ -120,7 +121,7 @@ def pad(
         output, image_size = self._F.pad_bounding_box(
             self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def rotate(
         self,
@@ -133,7 +134,7 @@ def rotate(
         output, image_size = self._F.rotate_bounding_box(
             self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def affine(
         self,
@@ -155,7 +156,7 @@ def affine(
             shear=shear,
             center=center,
         )
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def perspective(
         self,
@@ -164,7 +165,7 @@ def perspective(
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def elastic(
         self,
@@ -173,4 +174,4 @@ def elastic(
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(self, self.format, displacement)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)
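Outside this file the pattern is always the same: a functional kernel returns a plain tensor (and, for geometric ops, a new `image_size`), and `wrap_like` re-attaches the subclass plus any metadata that is not explicitly overridden. A rough usage sketch with illustrative values that are not taken from the test suite:

```python
import torch
from torchvision.prototype import features

bbox = features.BoundingBox([0, 0, 10, 5], format="XYXY", image_size=(32, 32))

# Re-wrap a kernel output; format and image_size are carried over from `bbox`.
flipped = features.BoundingBox.wrap_like(bbox, torch.tensor([22, 0, 32, 5]))
assert flipped.format == bbox.format and flipped.image_size == (32, 32)

# Metadata can be overridden when the operation changes it, e.g. after resizing to (64, 64).
resized = features.BoundingBox.wrap_like(bbox, torch.tensor([0, 0, 20, 10]), image_size=(64, 64))
assert resized.image_size == (64, 64)
```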
11 changes: 10 additions & 1 deletion torchvision/prototype/features/_encoded.py
@@ -14,6 +14,10 @@


 class EncodedData(_Feature):
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
+
     def __new__(
         cls,
         data: Any,
@@ -22,8 +26,13 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> EncodedData:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor)
+
+    @classmethod
+    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
+        return cls._wrap(tensor)

     @classmethod
     def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
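`EncodedData` is the degenerate case of the same pattern: it carries no metadata of its own, so `wrap_like` ignores `other` entirely and only re-attaches the subclass. A small sketch, assuming `EncodedData` is exposed under `torchvision.prototype.features` like the other feature types:

```python
import torch
from torchvision.prototype import features

raw = torch.randint(0, 256, (1024,), dtype=torch.uint8)  # stand-in for an encoded byte stream
template = features.EncodedData(raw)

# `template` only supplies the target type; no attributes are copied from it.
rewrapped = features.EncodedData.wrap_like(template, raw[128:])
assert type(rewrapped) is features.EncodedData
```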
55 changes: 23 additions & 32 deletions torchvision/prototype/features/_feature.py
@@ -21,48 +21,39 @@ def is_simple_tensor(inpt: Any) -> bool:
 class _Feature(torch.Tensor):
     __F: Optional[ModuleType] = None

-    def __new__(
-        cls: Type[F],
+    @staticmethod
+    def _to_tensor(
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
-    ) -> F:
-        return (
-            torch.as_tensor(  # type: ignore[return-value]
-                data,
-                dtype=dtype,
-                device=device,
-            )
-            .as_subclass(cls)
-            .requires_grad_(requires_grad)
-        )
+    ) -> torch.Tensor:
+        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    @classmethod
-    def new_like(
-        cls: Type[F],
-        other: F,
+    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a
+    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
+    # interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one
+    # public again.
+    def __new__(
+        cls,
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> F:
-        return cls(
-            data,
-            dtype=dtype if dtype is not None else other.dtype,
-            device=device if device is not None else other.device,
-            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
-            **kwargs,
-        )
+        requires_grad: bool = False,
+    ) -> _Feature:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return tensor.as_subclass(_Feature)
+
+    @classmethod
+    def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F:
+        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
+        # this method should be made abstract
+        # raise NotImplementedError
+        return tensor.as_subclass(cls)

     _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.new_like(
-            input, output, dtype=output.dtype, device=output.device
-        ),
+        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
+        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
         # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
         # retains the type automatically
         torch.Tensor.requires_grad_: lambda cls, input, output: output,
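The `_NO_WRAPPING_EXCEPTIONS` change is where the practical difference between the two helpers shows: `new_like` re-ran the feature constructor (and therefore `torch.as_tensor`) on the output and had to forward `dtype` and `device` explicitly, while `wrap_like` only calls `as_subclass` on the already-computed output, so dtype and device are whatever the operation produced. A rough sketch of the resulting behavior, using `Label` as a stand-in for any concrete feature:

```python
import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])

# clone() and to() are listed in _NO_WRAPPING_EXCEPTIONS, so their outputs are re-wrapped
# via wrap_like and keep the subclass; dtype/device come from the output tensor itself.
assert type(label.clone()) is features.Label
assert label.to(torch.float64).dtype == torch.float64

# Every other op returns an unwrapped tensor and has to be re-wrapped explicitly if needed.
doubled = features.Label.wrap_like(label, label * 2)
assert type(doubled) is features.Label
```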