Skip to content

[proto] Added RandomCrop transform and tests #6271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class TestSmoke:
transforms.RandomZoomOut(),
transforms.RandomRotation(degrees=(-45, 45)),
transforms.RandomAffine(degrees=(-45, 45)),
transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
)
def test_common(self, transform, input):
transform(input)
Expand Down Expand Up @@ -566,3 +567,80 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
params = transform._get_params(inpt)

fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)


class TestRandomCrop:
    """Tests for ``transforms.RandomCrop``: argument validation, parameter sampling and dispatch."""

    def test_assertions(self):
        # Every invalid constructor argument must be rejected with its dedicated error.
        with pytest.raises(ValueError, match="Please provide only two dimensions"):
            transforms.RandomCrop([10, 12, 14])

        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
            transforms.RandomCrop([10, 12], padding="abc")

        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
            transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomCrop([10, 12], padding=1, fill="abc")

        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")

    def test__get_params(self):
        image = features.Image(torch.rand(1, 3, 32, 32))
        h, w = image.shape[-2:]

        transform = transforms.RandomCrop([10, 10])
        params = transform._get_params(image)

        # The crop window has to fit inside the image, so the largest valid offset is
        # (image extent - crop extent); anything beyond that would run past the border.
        assert 0 <= params["top"] <= h - transform.size[0]
        assert 0 <= params["left"] <= w - transform.size[1]
        assert params["height"] == 10
        assert params["width"] == 10

    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
    @pytest.mark.parametrize("pad_if_needed", [False, True])
    @pytest.mark.parametrize("fill", [False, True])
    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
    def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
        output_size = [10, 12]
        transform = transforms.RandomCrop(
            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
        )

        inpt = features.Image(torch.rand(1, 3, 32, 32))

        # Expand the padding argument to per-side amounts the way the pad op documents it:
        # an int pads all four sides, [a, b] means left/right=a and top/bottom=b,
        # and [a, b, c, d] means left, top, right, bottom.
        if padding is None:
            pad_left = pad_top = pad_right = pad_bottom = 0
        elif isinstance(padding, int):
            pad_left = pad_top = pad_right = pad_bottom = padding
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left, pad_top, pad_right, pad_bottom = padding

        # Stand-in for the padded image that the patched F.pad returns.
        expected = mocker.MagicMock(spec=features.Image)
        expected.num_channels = 3
        # NOTE(review): assumes image_size is ordered (height, width) -- confirm against features.Image
        expected.image_size = (
            inpt.image_size[0] + pad_top + pad_bottom,
            inpt.image_size[1] + pad_left + pad_right,
        )
        _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
        fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")

        # vfdev-5, Feature Request: let's store params as Transform attribute
        # This could be also helpful for users
        # Until then, re-seed so that _get_params below draws the same offsets as the forward pass.
        torch.manual_seed(12)
        _ = transform(inpt)
        torch.manual_seed(12)

        # The 32x32 input (padded or not) is never smaller than the 10x12 crop, so the
        # pad_if_needed path does not add any padding here. What reaches F.crop therefore
        # depends only on whether explicit padding was requested.
        crop_input = inpt if padding is None else expected
        params = transform._get_params(crop_input)
        fn_crop.assert_called_once_with(
            crop_input, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
        )
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Resize,
CenterCrop,
RandomResizedCrop,
RandomCrop,
FiveCrop,
TenCrop,
BatchMultiCrop,
Expand Down
101 changes: 89 additions & 12 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.size = [size] if isinstance(size, int) else list(size)

self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.interpolation = interpolation
self.max_size = max_size
self.antialias = antialias
Expand Down Expand Up @@ -80,7 +81,6 @@ def __init__(
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")

self.size = size
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
Expand Down Expand Up @@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
raise TypeError("Got inappropriate fill arg")


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")

if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


class Pad(Transform):
def __init__(
self,
Expand All @@ -233,18 +246,10 @@ def __init__(
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")

if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)

_check_padding_arg(padding)
_check_fill_arg(fill)

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
_check_padding_mode_arg(padding_mode)

self.padding = padding
self.fill = fill
Expand Down Expand Up @@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill=self.fill,
center=self.center,
)


class RandomCrop(Transform):
    """Crop every item of a sample at one shared, randomly chosen location.

    Optionally pads the inputs first: unconditionally when ``padding`` is given, and
    additionally when ``pad_if_needed`` is set and the image is smaller than the crop.

    Args:
        size: Crop size, normalized by ``_setup_size`` and later unpacked as ``(height, width)``.
        padding: Padding forwarded to ``F.pad`` before cropping; ``None`` disables it.
        pad_if_needed: If True, pad inputs whose width or height is below the crop size.
        fill: Fill value forwarded to ``F.pad``.
        padding_mode: Padding mode forwarded to ``F.pad``.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if padding is not None:
            _check_padding_arg(padding)

        # fill and padding_mode are only validated when some padding can actually happen.
        if (padding is not None) or pad_if_needed:
            _check_padding_mode_arg(padding_mode)
            _check_fill_arg(fill)

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Sample the crop window for the image found in ``sample``.

        Returns:
            A dict with the keys ``top``, ``left``, ``height`` and ``width``.

        Raises:
            ValueError: if the crop size exceeds the image size in either dimension.
        """
        image = query_image(sample)
        _, height, width = get_image_dimensions(image)
        output_height, output_width = self.size

        # Reject every image that is smaller than the crop. The previous `+ 1` slack let an
        # image that is exactly one pixel too small through, which then made
        # torch.randint(0, 0, ...) below fail with an unrelated RuntimeError.
        if height < output_height or width < output_width:
            raise ValueError(
                f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}"
            )

        # Nothing to sample when the crop covers the whole image.
        if width == output_width and height == output_height:
            return dict(top=0, left=0, height=height, width=width)

        # The upper bound of randint is exclusive, hence the `+ 1`.
        top = torch.randint(0, height - output_height + 1, size=(1,)).item()
        left = torch.randint(0, width - output_width + 1, size=(1,)).item()
        return dict(top=top, left=left, height=output_height, width=output_width)

    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
        """Pad (if configured) and crop every element of the flattened sample."""
        if self.padding is not None:
            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]

        # Dimensions are read after the explicit padding, so pad_if_needed only fills what is still missing.
        image = query_image(flat_inputs)
        _, height, width = get_image_dimensions(image)

        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]

        # Parameters are sampled once, on the padded inputs, and shared by all elements.
        params = self._get_params(flat_inputs)

        return [F.crop(flat_input, **params) for flat_input in flat_inputs]

    def forward(self, *inputs: Any) -> Any:
        """Flatten the sample, pad and crop its leaves, and restore the original structure."""
        from torch.utils._pytree import tree_flatten, tree_unflatten

        # A single positional argument is treated as the sample itself, several as one tuple sample.
        sample = inputs if len(inputs) > 1 else inputs[0]

        flat_inputs, spec = tree_flatten(sample)
        out_flat_inputs = self._forward(flat_inputs)
        return tree_unflatten(out_flat_inputs, spec)
Comment on lines +468 to +495
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we can't just override the _transform method, as we need to access the image data and generate the params on the transformed (padded) sample.
The proposed solution is to 1) flatten the sample structure into a list, 2) apply the pad+crop logic, 3) unflatten the output back into the input sample structure.

13 changes: 12 additions & 1 deletion torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,28 @@

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision.prototype import features
from torchvision.prototype.utils._internal import query_recursively

from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil


def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
    """Return the first image-like leaf of ``sample``.

    ``sample`` is flattened with ``tree_flatten`` and its leaves are scanned in order.

    Raises:
        TypeError: if no leaf is a plain tensor, a PIL image or a ``features.Image``.
    """
    leaves, _ = tree_flatten(sample)
    for leaf in leaves:
        # Exact type match on purpose: subclasses of torch.Tensor do not qualify through this check.
        is_plain_tensor = type(leaf) == torch.Tensor
        if is_plain_tensor or isinstance(leaf, (PIL.Image.Image, features.Image)):
            return leaf

    raise TypeError("No image was found in the sample")


# vfdev-5: let's use tree_flatten instead of query_recursively and an internal fn to make the code simpler
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's simplify the code using tree_flatten instead of home-made methods.

def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier I suggest checking for image-like types in the following way. This makes it possible to use mocker with a spec type.

return id, input

return None
Expand Down