|
7 | 7 |
|
8 | 8 | from . import functional as F, InterpolationMode
|
9 | 9 |
|
10 |
| -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] |
| 10 | +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"] |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def _apply_op(
|
@@ -458,3 +458,154 @@ def __repr__(self) -> str:
|
458 | 458 | f")"
|
459 | 459 | )
|
460 | 460 | return s
|
| 461 | + |
| 462 | + |
| 463 | +class AugMix(torch.nn.Module): |
| 464 | + r"""AugMix data augmentation method based on |
| 465 | + `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_. |
| 466 | + If the image is torch Tensor, it should be of type torch.uint8, and it is expected |
| 467 | + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. |
| 468 | + If img is PIL Image, it is expected to be in mode "L" or "RGB". |
| 469 | +
|
| 470 | + Args: |
| 471 | + severity (int): The severity of base augmentation operators. Default is ``3``. |
| 472 | + mixture_width (int): The number of augmentation chains. Default is ``3``. |
| 473 | + chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3]. |
| 474 | + Default is ``-1``. |
| 475 | + alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``. |
| 476 | + all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``. |
| 477 | + interpolation (InterpolationMode): Desired interpolation enum defined by |
| 478 | + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. |
| 479 | + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. |
| 480 | + fill (sequence or number, optional): Pixel fill value for the area outside the transformed |
| 481 | + image. If given a number, the value is used for all bands respectively. |
| 482 | + """ |
| 483 | + |
| 484 | + def __init__( |
| 485 | + self, |
| 486 | + severity: int = 3, |
| 487 | + mixture_width: int = 3, |
| 488 | + chain_depth: int = -1, |
| 489 | + alpha: float = 1.0, |
| 490 | + all_ops: bool = True, |
| 491 | + interpolation: InterpolationMode = InterpolationMode.BILINEAR, |
| 492 | + fill: Optional[List[float]] = None, |
| 493 | + ) -> None: |
| 494 | + super().__init__() |
| 495 | + self._PARAMETER_MAX = 10 |
| 496 | + if not (1 <= severity <= self._PARAMETER_MAX): |
| 497 | + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") |
| 498 | + self.severity = severity |
| 499 | + self.mixture_width = mixture_width |
| 500 | + self.chain_depth = chain_depth |
| 501 | + self.alpha = alpha |
| 502 | + self.all_ops = all_ops |
| 503 | + self.interpolation = interpolation |
| 504 | + self.fill = fill |
| 505 | + |
| 506 | + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: |
| 507 | + s = { |
| 508 | + # op_name: (magnitudes, signed) |
| 509 | + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), |
| 510 | + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), |
| 511 | + "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), |
| 512 | + "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), |
| 513 | + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), |
| 514 | + "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), |
| 515 | + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), |
| 516 | + "AutoContrast": (torch.tensor(0.0), False), |
| 517 | + "Equalize": (torch.tensor(0.0), False), |
| 518 | + } |
| 519 | + if self.all_ops: |
| 520 | + s.update( |
| 521 | + { |
| 522 | + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), |
| 523 | + "Color": (torch.linspace(0.0, 0.9, num_bins), True), |
| 524 | + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), |
| 525 | + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), |
| 526 | + } |
| 527 | + ) |
| 528 | + return s |
| 529 | + |
| 530 | + @torch.jit.unused |
| 531 | + def _pil_to_tensor(self, img) -> Tensor: |
| 532 | + return F.pil_to_tensor(img) |
| 533 | + |
| 534 | + @torch.jit.unused |
| 535 | + def _tensor_to_pil(self, img: Tensor): |
| 536 | + return F.to_pil_image(img) |
| 537 | + |
| 538 | + def _sample_dirichlet(self, params: Tensor) -> Tensor: |
| 539 | + # Must be on a separate method so that we can overwrite it in tests. |
| 540 | + return torch._sample_dirichlet(params) |
| 541 | + |
| 542 | + def forward(self, orig_img: Tensor) -> Tensor: |
| 543 | + """ |
| 544 | + img (PIL Image or Tensor): Image to be transformed. |
| 545 | +
|
| 546 | + Returns: |
| 547 | + PIL Image or Tensor: Transformed image. |
| 548 | + """ |
| 549 | + fill = self.fill |
| 550 | + if isinstance(orig_img, Tensor): |
| 551 | + img = orig_img |
| 552 | + if isinstance(fill, (int, float)): |
| 553 | + fill = [float(fill)] * F.get_image_num_channels(img) |
| 554 | + elif fill is not None: |
| 555 | + fill = [float(f) for f in fill] |
| 556 | + else: |
| 557 | + img = self._pil_to_tensor(orig_img) |
| 558 | + |
| 559 | + op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) |
| 560 | + |
| 561 | + orig_dims = list(img.shape) |
| 562 | + batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) |
| 563 | + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) |
| 564 | + |
| 565 | + # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet |
| 566 | + # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. |
| 567 | + m = self._sample_dirichlet( |
| 568 | + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) |
| 569 | + ) |
| 570 | + |
| 571 | + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. |
| 572 | + combined_weights = self._sample_dirichlet( |
| 573 | + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) |
| 574 | + ) * m[:, 1].view([batch_dims[0], -1]) |
| 575 | + |
| 576 | + mix = m[:, 0].view(batch_dims) * batch |
| 577 | + for i in range(self.mixture_width): |
| 578 | + aug = batch |
| 579 | + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) |
| 580 | + for _ in range(depth): |
| 581 | + op_index = int(torch.randint(len(op_meta), (1,)).item()) |
| 582 | + op_name = list(op_meta.keys())[op_index] |
| 583 | + magnitudes, signed = op_meta[op_name] |
| 584 | + magnitude = ( |
| 585 | + float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item()) |
| 586 | + if magnitudes.ndim > 0 |
| 587 | + else 0.0 |
| 588 | + ) |
| 589 | + if signed and torch.randint(2, (1,)): |
| 590 | + magnitude *= -1.0 |
| 591 | + aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) |
| 592 | + mix.add_(combined_weights[:, i].view(batch_dims) * aug) |
| 593 | + mix = mix.view(orig_dims).to(dtype=img.dtype) |
| 594 | + |
| 595 | + if not isinstance(orig_img, Tensor): |
| 596 | + return self._tensor_to_pil(mix) |
| 597 | + return mix |
| 598 | + |
| 599 | + def __repr__(self) -> str: |
| 600 | + s = ( |
| 601 | + f"{self.__class__.__name__}(" |
| 602 | + f"severity={self.severity}" |
| 603 | + f", mixture_width={self.mixture_width}" |
| 604 | + f", chain_depth={self.chain_depth}" |
| 605 | + f", alpha={self.alpha}" |
| 606 | + f", all_ops={self.all_ops}" |
| 607 | + f", interpolation={self.interpolation}" |
| 608 | + f", fill={self.fill}" |
| 609 | + f")" |
| 610 | + ) |
| 611 | + return s |
0 commit comments