Skip to content

Commit b83d5f7

Browse files
authored
add support for apply probability to CutMix and MixUp (#6448)
1 parent 2a0eea8 commit b83d5f7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import PIL.Image
77
import torch
88
from torchvision.prototype import features
9-
from torchvision.prototype.transforms import functional as F, Transform
9+
from torchvision.prototype.transforms import functional as F
1010

1111
from ._transform import _RandomApplyTransform
1212
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
@@ -97,9 +97,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
9797
return inpt
9898

9999

100-
class _BaseMixupCutmix(Transform):
101-
def __init__(self, *, alpha: float) -> None:
102-
super().__init__()
100+
class _BaseMixupCutmix(_RandomApplyTransform):
101+
def __init__(self, *, alpha: float, p: float = 0.5) -> None:
102+
super().__init__(p=p)
103103
self.alpha = alpha
104104
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
105105

0 commit comments

Comments
 (0)