Skip to content

[proto] Draft for Transforms API v2 #6205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
25b3667
Added base tests for rotate_image_tensor
vfdev-5 Jun 22, 2022
6b3483d
Updated resize_image_tensor API and tests and fixed a bug with max_size
vfdev-5 Jun 22, 2022
ea7c513
Refactored and modified private api for resize functional op
vfdev-5 Jun 22, 2022
dc64e8a
Added base tests for rotate_image_tensor
vfdev-5 Jun 22, 2022
d341beb
Updated resize_image_tensor API and tests and fixed a bug with max_size
vfdev-5 Jun 22, 2022
aade78f
Fixed failures
vfdev-5 Jun 22, 2022
1e04ad3
Merge branch 'refactor-resize-max-size' into update-proto-test-rotate…
vfdev-5 Jun 22, 2022
a812a3b
More updates
vfdev-5 Jun 22, 2022
1aa6a78
Merge branch 'refactor-resize-max-size' into update-proto-test-rotate…
vfdev-5 Jun 22, 2022
6661d8d
Updated proto functional op: resize_image_*
vfdev-5 Jun 22, 2022
0972822
Fixed flake8
vfdev-5 Jun 22, 2022
f0c896f
Added max_size arg to resize_bounding_box and updated basic tests
vfdev-5 Jun 22, 2022
39e8bf6
Merge branch 'refactor-resize-max-size' into proto-transforms-api-oo
vfdev-5 Jun 22, 2022
c147e53
Merge branch 'update-proto-test-rotate-image' into proto-transforms-a…
vfdev-5 Jun 22, 2022
1a3a749
WIP Adding ops:
vfdev-5 Jun 22, 2022
740dfa7
Merge branch 'main' into update-proto-test-rotate-image
vfdev-5 Jun 23, 2022
6a5e5ab
Update functional.py
vfdev-5 Jun 23, 2022
8fcf4fa
Merge branch 'main' of github.com:pytorch/vision into proto-transform…
vfdev-5 Jun 23, 2022
b2ada45
Reverted fill/center order for rotate
vfdev-5 Jun 23, 2022
4af25c6
Merge branch 'update-proto-test-rotate-image' of github.com:vfdev-5/v…
vfdev-5 Jun 23, 2022
ff80373
Merge branch 'main' of github.com:pytorch/vision into proto-transform…
vfdev-5 Jun 23, 2022
5bc6a50
- Added Pad and WIP on Rotate
vfdev-5 Jun 27, 2022
6a5201a
Added more ops and mid-level functional API
vfdev-5 Jun 27, 2022
b659a08
- Added more color ops
vfdev-5 Jun 27, 2022
7a8f950
Added more f-mid-level ops
vfdev-5 Jun 28, 2022
7917a17
_check_inpt -> _check_input
vfdev-5 Jun 28, 2022
99bfad9
Fixed broken code, added a test for mid-level ops
vfdev-5 Jun 28, 2022
ae5eef9
Fixed bugs and started porting transforms
vfdev-5 Jun 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_segmentation_masks,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

Expand All @@ -25,56 +30,59 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
yield bounding_box.data


def parametrize(transforms_with_inputs):
def parametrize(transforms_with_inpts):
return pytest.mark.parametrize(
("transform", "input"),
("transform", "inpt"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
inpt,
id=f"{type(transform).__name__}-{type(inpt).__module__}.{type(inpt).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
for transform, inpts in transforms_with_inpts
for idx, inpt in enumerate(inpts)
],
)


def parametrize_from_transforms(*transforms):
transforms_with_inputs = []
transforms_with_inpts = []
for transform in transforms:
for creation_fn in [
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_segmentation_masks,
]:
inputs = list(creation_fn())
inpts = list(creation_fn())
try:
output = transform(inputs[0])
except Exception:
output = transform(inpts[0])
except (TypeError, RuntimeError):
continue
else:
if output is inputs[0]:
if output is inpts[0]:
continue

transforms_with_inputs.append((transform, inputs))
transforms_with_inpts.append((transform, inpts))

return parametrize(transforms_with_inputs)
return parametrize(transforms_with_inpts)


class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.RandomResizedCrop([16, 16]),
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
)
def test_common(self, transform, input):
transform(input)
def test_common(self, transform, inpt):
output = transform(inpt)
assert type(output) == type(inpt)

@parametrize(
[
Expand All @@ -96,8 +104,8 @@ def test_common(self, transform, input):
]
]
)
def test_mixup_cutmix(self, transform, input):
transform(input)
def test_mixup_cutmix(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand Down Expand Up @@ -127,8 +135,8 @@ def test_mixup_cutmix(self, transform, input):
)
]
)
def test_auto_augment(self, transform, input):
transform(input)
def test_auto_augment(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand All @@ -144,8 +152,8 @@ def test_auto_augment(self, transform, input):
),
]
)
def test_normalize(self, transform, input):
transform(input)
def test_normalize(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand All @@ -159,8 +167,8 @@ def test_normalize(self, transform, input):
)
]
)
def test_random_resized_crop(self, transform, input):
transform(input)
def test_random_resized_crop(self, transform, inpt):
transform(inpt)

@parametrize(
[
Expand Down Expand Up @@ -188,111 +196,111 @@ def test_random_resized_crop(self, transform, input):
)
]
)
def test_convert_image_color_space(self, transform, input):
transform(input)
def test_convert_image_color_space(self, transform, inpt):
transform(inpt)


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
def inpt_expected_image_tensor(self, p, dtype=torch.float32):
inpt = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)

return input, expected if p == 1 else input
return inpt, expected if p == 1 else inpt

def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

assert_equal(expected, actual)

def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(to_pil_image(input))
actual = transform(to_pil_image(inpt))

assert_equal(expected, pil_to_tensor(actual))

def test_features_image(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(features.Image(input))
actual = transform(features.Image(inpt))

assert_equal(features.Image(expected), actual)

def test_features_segmentation_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(features.SegmentationMask(input))
actual = transform(features.SegmentationMask(inpt))

assert_equal(features.SegmentationMask(expected), actual)

def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
transform = transforms.RandomHorizontalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else inpt
expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.image_size == expected.image_size


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomVerticalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
def inpt_expected_image_tensor(self, p, dtype=torch.float32):
inpt = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)

return input, expected if p == 1 else input
return inpt, expected if p == 1 else inpt

def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

assert_equal(expected, actual)

def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(to_pil_image(input))
actual = transform(to_pil_image(inpt))

assert_equal(expected, pil_to_tensor(actual))

def test_features_image(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(features.Image(input))
actual = transform(features.Image(inpt))

assert_equal(features.Image(expected), actual)

def test_features_segmentation_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
inpt, expected = self.inpt_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(features.SegmentationMask(input))
actual = transform(features.SegmentationMask(inpt))

assert_equal(features.SegmentationMask(expected), actual)

def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)
actual = transform(inpt)

expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else inpt
expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.image_size == expected.image_size
32 changes: 28 additions & 4 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,16 +489,40 @@ def center_crop_segmentation_mask():
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
and name
not in {
"to_image_tensor",
}
and name not in {"to_image_tensor"}
],
)
def test_scriptable(kernel):
jit.script(kernel)


# Collect every public callable of ``F`` whose name mentions none of the
# feature-type markers below (presumably these are the mid-level dispatchers
# rather than the per-type kernels -- TODO confirm against ``F``).
@pytest.mark.parametrize(
    "func",
    [
        pytest.param(attr, id=attr_name)
        for attr_name, attr in F.__dict__.items()
        if not attr_name.startswith("_")
        and callable(attr)
        and not any(
            feature_type in attr_name
            for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
        )
        and attr_name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av"}
    ],
)
def test_functional_mid_level(func):
    """Compare ``func`` against each functional info whose name contains ``<func name>_``.

    For every matching entry of ``FUNCTIONAL_INFOS`` the first sample input is
    run through the info itself to get the expected value, and through ``func``
    with the ``format`` / ``image_size`` keyword arguments removed; both
    results must be close.

    NOTE(review): nesting was reconstructed from a flattened listing; the
    ``break`` is assumed to end the sample loop after its first iteration --
    TODO confirm against the original file.
    """
    for finfo in FUNCTIONAL_INFOS:
        if f"{func.__name__}_" not in finfo.name:
            continue

        for sample_input in finfo.sample_inputs():
            expected = finfo(sample_input)

            # Presumably the mid-level op reads this metadata from the input
            # itself, so it is not forwarded explicitly -- verify.
            forwarded_kwargs = dict(sample_input.kwargs)
            for dropped_key in ("format", "image_size"):
                forwarded_kwargs.pop(dropped_key, None)

            output = func(*sample_input.args, **forwarded_kwargs)
            torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}")
            # Only the first sample input of each info is checked.
            break


@pytest.mark.parametrize(
("functional_info", "sample_input"),
[
Expand Down
Loading