Skip to content

Commit cf69db9

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] Add RandomEqualize prototype transforms (#5807)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095678 fbshipit-source-id: 449fc4a4d9f0a3fcc2b7314001607c28a247ba21
1 parent a068746 commit cf69db9

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._augment import RandomErasing, RandomMixup, RandomCutmix
66
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
7-
from ._color import ColorJitter, RandomPhotometricDistort
7+
from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize
88
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
99
from ._geometry import (
1010
Resize,

torchvision/prototype/transforms/_color.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchvision.prototype.transforms import Transform, functional as F
99
from torchvision.transforms import functional as _F
1010

11+
from ._transform import _RandomApplyTransform
1112
from ._utils import is_simple_tensor, get_image_dimensions, query_image
1213

1314
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
@@ -188,3 +189,19 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
188189
if params["channel_shuffle"]:
189190
input = self._channel_shuffle(input)
190191
return input
192+
193+
194+
class RandomEqualize(_RandomApplyTransform):
195+
def __init__(self, p: float = 0.5):
196+
super().__init__(p=p)
197+
198+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
199+
if isinstance(input, features.Image):
200+
output = F.equalize_image_tensor(input)
201+
return features.Image.new_like(input, output)
202+
elif is_simple_tensor(input):
203+
return F.equalize_image_tensor(input)
204+
elif isinstance(input, PIL.Image.Image):
205+
return F.equalize_image_pil(input)
206+
else:
207+
return input

0 commit comments

Comments
 (0)