Skip to content

Commit 30bbae9

Browse files
committed
Fix search space of TrivialAugment.
1 parent 2933667 commit 30bbae9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchvision/transforms/autoaugment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
4444
img = F.equalize(img)
4545
elif op_name == "Invert":
4646
img = F.invert(img)
47+
elif op_name == "Identity":
48+
pass
4749
else:
4850
raise ValueError("The provided operator {} is not recognized.".format(op_name))
4951
return img
@@ -353,6 +355,7 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
353355
def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
354356
return {
355357
# op_name: (magnitudes, signed)
358+
"Identity": (torch.tensor(0.0), False),
356359
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
357360
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
358361
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
@@ -366,7 +369,6 @@ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
366369
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
367370
"AutoContrast": (torch.tensor(0.0), False),
368371
"Equalize": (torch.tensor(0.0), False),
369-
"Invert": (torch.tensor(0.0), False),
370372
}
371373

372374
def forward(self, img: Tensor) -> Tensor:

0 commit comments

Comments
 (0)