Skip to content

Commit 2586de6

Browse files
authored
Added erase_image_pil and eager/jit erase_image_tensor test (#6320)
1 parent 23112f8 commit 2586de6

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,13 @@ def adjust_sharpness_image_tensor():
654654
yield SampleInput(image, sharpness_factor=sharpness_factor)
655655

656656

657+
@register_kernel_info_from_sample_inputs_fn
658+
def erase_image_tensor():
659+
for image in make_images():
660+
c = image.shape[-3]
661+
yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
662+
663+
657664
@pytest.mark.parametrize(
658665
"kernel",
659666
[

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import functional as F, Transform
10-
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
1110

1211
from ._transform import _RandomApplyTransform
1312
from ._utils import get_image_dimensions, has_all, has_any, query_image
@@ -93,9 +92,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
9392
return features.Image.new_like(inpt, output)
9493
return output
9594
elif isinstance(inpt, PIL.Image.Image):
96-
t_img = pil_to_tensor(inpt)
97-
output = F.erase_image_tensor(t_img, **params)
98-
return to_pil_image(output, mode=inpt.mode)
95+
return F.erase_image_pil(inpt, **params)
9996
else:
10097
return inpt
10198

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
convert_image_color_space_pil,
66
) # usort: skip
77

8-
from ._augment import erase_image_tensor
8+
from ._augment import erase_image_pil, erase_image_tensor
99
from ._color import (
1010
adjust_brightness,
1111
adjust_brightness_image_pil,
Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
import PIL.Image
2+
3+
import torch
14
from torchvision.transforms import functional_tensor as _FT
5+
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
26

37

48
erase_image_tensor = _FT.erase
59

610

7-
# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels.
8-
# Like the mixup and cutmix stuff
9-
10-
# This function is copy-pasted to Image and OneHotLabel and may be refactored
11-
# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
12-
# input = input.clone()
13-
# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))
11+
def erase_image_pil(
12+
img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
13+
) -> PIL.Image.Image:
14+
t_img = pil_to_tensor(img)
15+
output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
16+
return to_pil_image(output, mode=img.mode)

0 commit comments

Comments
 (0)