Skip to content

Commit 684a48f

Browse files
authored
Merge branch 'main' into malfet/add-python-3.10
2 parents 1a2d105 + 7bb5e41 commit 684a48f

File tree

7 files changed

+264
-3
lines changed

7 files changed

+264
-3
lines changed

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran
198198
AutoAugment
199199
RandAugment
200200
TrivialAugmentWide
201+
AugMix
201202

202203
.. _functional_transforms:
203204

gallery/plot_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
263263
imgs = [augmenter(orig_img) for _ in range(4)]
264264
plot(imgs)
265265

266+
####################################
267+
# AugMix
268+
# ~~~~~~
269+
# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data.
270+
augmenter = T.AugMix()
271+
imgs = [augmenter(orig_img) for _ in range(4)]
272+
plot(imgs)
273+
266274
####################################
267275
# Randomly-applied transforms
268276
# ---------------------------

references/classification/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(
2222
trans.append(autoaugment.RandAugment(interpolation=interpolation))
2323
elif auto_augment_policy == "ta_wide":
2424
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
25+
elif auto_augment_policy == "augmix":
26+
trans.append(autoaugment.AugMix(interpolation=interpolation))
2527
else:
2628
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
2729
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))

references/detection/transforms.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torchvision
55
from torch import nn, Tensor
66
from torchvision.transforms import functional as F
7-
from torchvision.transforms import transforms as T
7+
from torchvision.transforms import transforms as T, InterpolationMode
88

99

1010
def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,52 @@ def forward(
282282
image = F.to_pil_image(image)
283283

284284
return image, target
285+
286+
287+
class ScaleJitter(nn.Module):
288+
"""Randomly resizes the image and its bounding boxes within the specified scale range.
289+
The class implements the Scale Jitter augmentation as described in the paper
290+
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
291+
292+
Args:
293+
target_size (tuple of ints): The target size for the transform provided in (height, weight) format.
294+
scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
295+
range a <= scale <= b.
296+
interpolation (InterpolationMode): Desired interpolation enum defined by
297+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
298+
"""
299+
300+
def __init__(
301+
self,
302+
target_size: Tuple[int, int],
303+
scale_range: Tuple[float, float] = (0.1, 2.0),
304+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
305+
):
306+
super().__init__()
307+
self.target_size = target_size
308+
self.scale_range = scale_range
309+
self.interpolation = interpolation
310+
311+
def forward(
312+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
313+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
314+
if isinstance(image, torch.Tensor):
315+
if image.ndimension() not in {2, 3}:
316+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
317+
elif image.ndimension() == 2:
318+
image = image.unsqueeze(0)
319+
320+
r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
321+
new_width = int(self.target_size[1] * r)
322+
new_height = int(self.target_size[0] * r)
323+
324+
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
325+
326+
if target is not None:
327+
target["boxes"] *= r
328+
if "masks" in target:
329+
target["masks"] = F.resize(
330+
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
331+
)
332+
333+
return image, target

test/test_transforms.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
16011601
transform.__repr__()
16021602

16031603

1604+
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
1605+
@pytest.mark.parametrize("severity", [1, 10])
1606+
@pytest.mark.parametrize("mixture_width", [1, 2])
1607+
@pytest.mark.parametrize("chain_depth", [-1, 2])
1608+
@pytest.mark.parametrize("all_ops", [True, False])
1609+
@pytest.mark.parametrize("grayscale", [True, False])
1610+
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
1611+
random.seed(42)
1612+
img = Image.open(GRACE_HOPPER)
1613+
if grayscale:
1614+
img, fill = _get_grayscale_test_image(img, fill)
1615+
transform = transforms.AugMix(
1616+
fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
1617+
)
1618+
for _ in range(100):
1619+
img = transform(img)
1620+
transform.__repr__()
1621+
1622+
16041623
def test_random_crop():
16051624
height = random.randint(10, 32) * 2
16061625
width = random.randint(10, 32) * 2

test/test_transforms_tensor.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill):
720720
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
721721

722722

723-
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
723+
@pytest.mark.parametrize("device", cpu_and_gpu())
724+
@pytest.mark.parametrize(
725+
"fill",
726+
[
727+
None,
728+
85,
729+
(10, -10, 10),
730+
0.7,
731+
[0.0, 0.0, 0.0],
732+
[
733+
1,
734+
],
735+
1,
736+
],
737+
)
738+
def test_augmix(device, fill):
739+
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
740+
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
741+
742+
class DeterministicAugMix(T.AugMix):
743+
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
744+
# patch the method to ensure that the order of rand calls doesn't affect the outcome
745+
return params.softmax(dim=-1)
746+
747+
transform = DeterministicAugMix(fill=fill)
748+
s_transform = torch.jit.script(transform)
749+
for _ in range(25):
750+
_test_transform_vs_scripted(transform, s_transform, tensor)
751+
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
752+
753+
754+
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
724755
def test_autoaugment_save(augmentation, tmpdir):
725756
transform = augmentation()
726757
s_transform = torch.jit.script(transform)

torchvision/transforms/autoaugment.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
1111

1212

1313
def _apply_op(
@@ -458,3 +458,154 @@ def __repr__(self) -> str:
458458
f")"
459459
)
460460
return s
461+
462+
463+
class AugMix(torch.nn.Module):
464+
r"""AugMix data augmentation method based on
465+
`"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
466+
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
467+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
468+
If img is PIL Image, it is expected to be in mode "L" or "RGB".
469+
470+
Args:
471+
severity (int): The severity of base augmentation operators. Default is ``3``.
472+
mixture_width (int): The number of augmentation chains. Default is ``3``.
473+
chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
474+
Default is ``-1``.
475+
alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
476+
all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
477+
interpolation (InterpolationMode): Desired interpolation enum defined by
478+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
479+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
480+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
481+
image. If given a number, the value is used for all bands respectively.
482+
"""
483+
484+
def __init__(
485+
self,
486+
severity: int = 3,
487+
mixture_width: int = 3,
488+
chain_depth: int = -1,
489+
alpha: float = 1.0,
490+
all_ops: bool = True,
491+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
492+
fill: Optional[List[float]] = None,
493+
) -> None:
494+
super().__init__()
495+
self._PARAMETER_MAX = 10
496+
if not (1 <= severity <= self._PARAMETER_MAX):
497+
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
498+
self.severity = severity
499+
self.mixture_width = mixture_width
500+
self.chain_depth = chain_depth
501+
self.alpha = alpha
502+
self.all_ops = all_ops
503+
self.interpolation = interpolation
504+
self.fill = fill
505+
506+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
507+
s = {
508+
# op_name: (magnitudes, signed)
509+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
510+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
511+
"TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
512+
"TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
513+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
514+
"Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
515+
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
516+
"AutoContrast": (torch.tensor(0.0), False),
517+
"Equalize": (torch.tensor(0.0), False),
518+
}
519+
if self.all_ops:
520+
s.update(
521+
{
522+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
523+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
524+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
525+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
526+
}
527+
)
528+
return s
529+
530+
@torch.jit.unused
531+
def _pil_to_tensor(self, img) -> Tensor:
532+
return F.pil_to_tensor(img)
533+
534+
@torch.jit.unused
535+
def _tensor_to_pil(self, img: Tensor):
536+
return F.to_pil_image(img)
537+
538+
def _sample_dirichlet(self, params: Tensor) -> Tensor:
539+
# Must be on a separate method so that we can overwrite it in tests.
540+
return torch._sample_dirichlet(params)
541+
542+
def forward(self, orig_img: Tensor) -> Tensor:
543+
"""
544+
img (PIL Image or Tensor): Image to be transformed.
545+
546+
Returns:
547+
PIL Image or Tensor: Transformed image.
548+
"""
549+
fill = self.fill
550+
if isinstance(orig_img, Tensor):
551+
img = orig_img
552+
if isinstance(fill, (int, float)):
553+
fill = [float(fill)] * F.get_image_num_channels(img)
554+
elif fill is not None:
555+
fill = [float(f) for f in fill]
556+
else:
557+
img = self._pil_to_tensor(orig_img)
558+
559+
op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
560+
561+
orig_dims = list(img.shape)
562+
batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
563+
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
564+
565+
# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
566+
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
567+
m = self._sample_dirichlet(
568+
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
569+
)
570+
571+
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
572+
combined_weights = self._sample_dirichlet(
573+
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
574+
) * m[:, 1].view([batch_dims[0], -1])
575+
576+
mix = m[:, 0].view(batch_dims) * batch
577+
for i in range(self.mixture_width):
578+
aug = batch
579+
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
580+
for _ in range(depth):
581+
op_index = int(torch.randint(len(op_meta), (1,)).item())
582+
op_name = list(op_meta.keys())[op_index]
583+
magnitudes, signed = op_meta[op_name]
584+
magnitude = (
585+
float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
586+
if magnitudes.ndim > 0
587+
else 0.0
588+
)
589+
if signed and torch.randint(2, (1,)):
590+
magnitude *= -1.0
591+
aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
592+
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
593+
mix = mix.view(orig_dims).to(dtype=img.dtype)
594+
595+
if not isinstance(orig_img, Tensor):
596+
return self._tensor_to_pil(mix)
597+
return mix
598+
599+
def __repr__(self) -> str:
600+
s = (
601+
f"{self.__class__.__name__}("
602+
f"severity={self.severity}"
603+
f", mixture_width={self.mixture_width}"
604+
f", chain_depth={self.chain_depth}"
605+
f", alpha={self.alpha}"
606+
f", all_ops={self.all_ops}"
607+
f", interpolation={self.interpolation}"
608+
f", fill={self.fill}"
609+
f")"
610+
)
611+
return s

0 commit comments

Comments
 (0)