Skip to content

[proto] Draft for Transforms API v2 #6205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
25b3667
Added base tests for rotate_image_tensor
vfdev-5 Jun 22, 2022
6b3483d
Updated resize_image_tensor API and tests and fixed a bug with max_size
vfdev-5 Jun 22, 2022
ea7c513
Refactored and modified private api for resize functional op
vfdev-5 Jun 22, 2022
dc64e8a
Added base tests for rotate_image_tensor
vfdev-5 Jun 22, 2022
d341beb
Updated resize_image_tensor API and tests and fixed a bug with max_size
vfdev-5 Jun 22, 2022
aade78f
Fixed failures
vfdev-5 Jun 22, 2022
1e04ad3
Merge branch 'refactor-resize-max-size' into update-proto-test-rotate…
vfdev-5 Jun 22, 2022
a812a3b
More updates
vfdev-5 Jun 22, 2022
1aa6a78
Merge branch 'refactor-resize-max-size' into update-proto-test-rotate…
vfdev-5 Jun 22, 2022
6661d8d
Updated proto functional op: resize_image_*
vfdev-5 Jun 22, 2022
0972822
Fixed flake8
vfdev-5 Jun 22, 2022
f0c896f
Added max_size arg to resize_bounding_box and updated basic tests
vfdev-5 Jun 22, 2022
39e8bf6
Merge branch 'refactor-resize-max-size' into proto-transforms-api-oo
vfdev-5 Jun 22, 2022
c147e53
Merge branch 'update-proto-test-rotate-image' into proto-transforms-a…
vfdev-5 Jun 22, 2022
1a3a749
WIP Adding ops:
vfdev-5 Jun 22, 2022
740dfa7
Merge branch 'main' into update-proto-test-rotate-image
vfdev-5 Jun 23, 2022
6a5e5ab
Update functional.py
vfdev-5 Jun 23, 2022
8fcf4fa
Merge branch 'main' of github.com:pytorch/vision into proto-transform…
vfdev-5 Jun 23, 2022
b2ada45
Reverted fill/center order for rotate
vfdev-5 Jun 23, 2022
4af25c6
Merge branch 'update-proto-test-rotate-image' of github.com:vfdev-5/v…
vfdev-5 Jun 23, 2022
ff80373
Merge branch 'main' of github.com:pytorch/vision into proto-transform…
vfdev-5 Jun 23, 2022
5bc6a50
- Added Pad and WIP on Rotate
vfdev-5 Jun 27, 2022
6a5201a
Added more ops and mid-level functional API
vfdev-5 Jun 27, 2022
b659a08
- Added more color ops
vfdev-5 Jun 27, 2022
7a8f950
Added more f-mid-level ops
vfdev-5 Jun 28, 2022
7917a17
_check_inpt -> _check_input
vfdev-5 Jun 28, 2022
99bfad9
Fixed broken code, added a test for mid-level ops
vfdev-5 Jun 28, 2022
ae5eef9
Fixed bugs and started porting transforms
vfdev-5 Jun 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_segmentation_masks,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

Expand All @@ -25,56 +30,59 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
yield bounding_box.data


def parametrize(transforms_with_inputs):
def parametrize(transforms_with_inpts):
return pytest.mark.parametrize(
("transform", "input"),
("transform", "inpt"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
inpt,
id=f"{type(transform).__name__}-{type(inpt).__module__}.{type(inpt).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
for transform, inpts in transforms_with_inpts
for idx, inpt in enumerate(inpts)
],
)


def parametrize_from_transforms(*transforms):
transforms_with_inputs = []
transforms_with_inpts = []
for transform in transforms:
for creation_fn in [
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_segmentation_masks,
]:
inputs = list(creation_fn())
inpts = list(creation_fn())
try:
output = transform(inputs[0])
except Exception:
output = transform(inpts[0])
except TypeError:
continue
else:
if output is inputs[0]:
if output is inpts[0]:
continue

transforms_with_inputs.append((transform, inputs))
transforms_with_inpts.append((transform, inpts))

return parametrize(transforms_with_inputs)
return parametrize(transforms_with_inpts)


class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.RandomResizedCrop([16, 16]),
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
)
def test_common(self, transform, input):
transform(input)
def test_common(self, transform, inpt):
output = transform(inpt)
assert type(output) == type(inpt)

@parametrize(
[
Expand All @@ -96,8 +104,8 @@ def test_common(self, transform, input):
]
]
)
def test_mixup_cutmix(self, transform, input):
transform(input)
def test_mixup_cutmix(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand Down Expand Up @@ -127,8 +135,8 @@ def test_mixup_cutmix(self, transform, input):
)
]
)
def test_auto_augment(self, transform, input):
transform(input)
def test_auto_augment(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand All @@ -144,8 +152,8 @@ def test_auto_augment(self, transform, input):
),
]
)
def test_normalize(self, transform, input):
transform(input)
def test_normalize(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand All @@ -159,8 +167,8 @@ def test_normalize(self, transform, input):
)
]
)
def test_random_resized_crop(self, transform, input):
transform(input)
def test_random_resized_crop(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand Down Expand Up @@ -188,111 +196,111 @@ def test_random_resized_crop(self, transform, input):
)
]
)
def test_convert_image_color_space(self, transform, input):
transform(input)
def test_convert_image_color_space(self, transform, inpt):
transform(inpt)


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
def inpt_expected_image_tensor(self, p, dtype=torch.float32):
inpt = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)

return input, expected if p == 1 else input
return inpt, expected if p == 1 else inpt

def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

assert_equal(expected, actual)

def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(to_pil_image(input))
actual = transform(to_pil_image(inpt))

assert_equal(expected, pil_to_tensor(actual))

def test_features_image(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(features.Image(input))
actual = transform(features.Image(inpt))

assert_equal(features.Image(expected), actual)

def test_features_segmentation_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(features.SegmentationMask(input))
actual = transform(features.SegmentationMask(inpt))

assert_equal(features.SegmentationMask(expected), actual)

def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else inpt
expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.image_size == expected.image_size


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomVerticalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
def inpt_expected_image_tensor(self, p, dtype=torch.float32):
inpt = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)

return input, expected if p == 1 else input
return inpt, expected if p == 1 else inpt

def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

assert_equal(expected, actual)

def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(to_pil_image(input))
actual = transform(to_pil_image(inpt))

assert_equal(expected, pil_to_tensor(actual))

def test_features_image(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(features.Image(input))
actual = transform(features.Image(inpt))

assert_equal(features.Image(expected), actual)

def test_features_segmentation_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(features.SegmentationMask(input))
actual = transform(features.SegmentationMask(inpt))

assert_equal(features.SegmentationMask(expected), actual)

def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else inpt
expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.image_size == expected.image_size
46 changes: 46 additions & 0 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,49 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)

def horizontal_flip(self) -> BoundingBox:
    """Return a new box mirrored along the vertical axis of its image."""
    flipped = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
    return BoundingBox.new_like(self, flipped)

def vertical_flip(self) -> BoundingBox:
    """Return a new box mirrored along the horizontal axis of its image."""
    flipped = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
    return BoundingBox.new_like(self, flipped)

def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox:
    """Scale the box coordinates to match an image resized to ``size``.

    ``interpolation`` and ``antialias`` are accepted only to keep the
    signature uniform with the image variant; they have no effect on boxes.
    """
    interpolation, antialias  # unused
    output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
    # NOTE(review): when max_size caps the requested size, the actual output
    # image size may differ from ``size`` — confirm resize_bounding_box's
    # contract before relying on the image_size recorded here.
    return BoundingBox.new_like(self, output, image_size=size)

def center_crop(self, output_size) -> BoundingBox:
    """Translate the box into the coordinate frame of a center crop of ``output_size``."""
    cropped = self._F.center_crop_bounding_box(
        self,
        format=self.format,
        output_size=output_size,
        image_size=self.image_size,
    )
    return BoundingBox.new_like(self, cropped, image_size=output_size)

def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> BoundingBox:
    """Crop the region ``(top, left, height, width)`` and rescale the box to ``size``.

    ``interpolation`` and ``antialias`` exist only for signature parity with
    the image op and are ignored for boxes.
    """
    # TODO: untested right now
    interpolation, antialias  # unused
    boxed = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
    return BoundingBox.new_like(self, boxed, image_size=size)

def pad(self, padding, *, fill, padding_mode) -> BoundingBox:
    """Shift the box to account for padding added around its image.

    Args:
        padding: an int (all four sides), or a sequence of 1, 2 or 4 ints
            following the torchvision convention — ``(left/right, top/bottom)``
            for two values, ``(left, top, right, bottom)`` for four.
        fill: accepted only for signature parity with the image op; has no
            effect on box coordinates.
        padding_mode: only ``"constant"`` is supported for boxes.

    Raises:
        ValueError: if ``padding_mode`` is not ``"constant"``.
    """
    fill  # unused
    if padding_mode not in ["constant"]:
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    output = self._F.pad_bounding_box(self, padding, fill=fill, padding_mode=padding_mode)

    # Normalize padding before unpacking: the original 4-tuple unpacking
    # crashed for int padding (e.g. transforms.Pad(5)) and 1/2-element
    # sequences, all of which torchvision's pad accepts.
    if isinstance(padding, int):
        left = top = right = bottom = padding
    elif len(padding) == 1:
        left = top = right = bottom = padding[0]
    elif len(padding) == 2:
        left = right = padding[0]
        top = bottom = padding[1]
    else:
        left, top, right, bottom = padding

    # Update output image size to include the padded border.
    height, width = self.image_size
    height += top + bottom
    width += left + right

    return BoundingBox.new_like(self, output, image_size=(height, width))

def rotate(self, angle, *, interpolation, expand, fill, center) -> BoundingBox:
    """Rotate the bounding box by ``angle`` degrees around ``center``."""
    output = self._F.rotate_bounding_box(
        self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
    )
    # NOTE(review): with expand=True the image canvas grows, yet new_like
    # keeps the original image_size here — confirm whether image_size should
    # be recomputed, as pad() does.
    return BoundingBox.new_like(self, output)
47 changes: 46 additions & 1 deletion torchvision/prototype/features/_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,52 @@
F = TypeVar("F", bound="_Feature")


class _Feature(torch.Tensor):
class _TransformsMixin:
def __init__(self, *args, **kwargs):
    """Bind the functional transforms module onto the instance as ``self._F``."""
    super().__init__()

    # Imported lazily (inside __init__) to avoid a circular import between
    # the features and transforms packages at module load time.
    from ..transforms import functional as F

    self._F = F

def horizontal_flip(self):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would raise a "not implemented" error here and implement no-op explicitly on the implementations. So if bbox can't blur, then it should explicitly implement to be no-op as opposed to leaving the default implementation which returns self. Could help us avoid issues but if you disagree I'm happy to leave as-is and discuss later.

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Jun 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that raising NotImplementedError could be good. The only drawback I found is that all target features like Label, OneHotLabel should also implement these no-ops explicitly. At first glance, implementing all methods like OneHotLabel.blur(self, *args, **kwargs) as no-ops made me think that it is a bit too much...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you. I've added this for your consideration. Feel free to ignore or postpone for later.


def vertical_flip(self):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self

def resize(self, size, *, interpolation, max_size, antialias):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged; all arguments are intentionally ignored.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self

def center_crop(self, output_size):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged; output_size is intentionally ignored.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self

def resized_crop(self, top, left, height, width, *, size, interpolation, antialias):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged; all arguments are intentionally ignored.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self

def pad(self, padding, *, fill, padding_mode):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged; all arguments are intentionally ignored.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self

def rotate(self, angle, *, interpolation, expand, fill, center):
    # Default no-op so non-spatial features (labels, etc.) pass through
    # joint transforms unchanged; all arguments are intentionally ignored.
    # NOTE(review): consider raising NotImplementedError instead, so missing
    # overrides fail loudly rather than silently returning self.
    return self


class _Feature(_TransformsMixin, torch.Tensor):
def __new__(
cls: Type[F],
data: Any,
Expand Down
Loading