Skip to content

Commit 0cc1c75

Browse files
datumbox and facebook-github-bot
authored and committed
[fbsync] [proto] Speed improvements for adjust hue op (#6805)
Summary: * WIP * Updated rgb2hsv and a bit of hsv2rgb * Fix issue with batch of images * Few improvements * hsv2rgb improvements * PR review * another update * Fix cuda issue with empty images torch.aminmax is failing Reviewed By: YosuaMichael Differential Revision: D40722899 fbshipit-source-id: 59edbba970a015fbc58c26828b36197945f46080 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 27fc3e6 commit 0cc1c75

File tree

1 file changed

+98
-1
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+98
-1
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
143143
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
144144

145145

146-
adjust_hue_image_tensor = _FT.adjust_hue
146+
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
147+
r, g, _ = image.unbind(dim=-3)
148+
149+
# Implementation is based on
150+
# https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
151+
minc, maxc = torch.aminmax(image, dim=-3)
152+
153+
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
154+
# from happening in the results, because
155+
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
156+
# + H channel has division by `(maxc - minc)`.
157+
#
158+
# Instead of overwriting NaN afterwards, we just prevent it from occuring so
159+
# we don't need to deal with it in case we save the NaN in a buffer in
160+
# backprop, if it is ever supported, but it doesn't hurt to do so.
161+
eqc = maxc == minc
162+
163+
channels_range = maxc - minc
164+
# Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine.
165+
ones = torch.ones_like(maxc)
166+
s = channels_range / torch.where(eqc, ones, maxc)
167+
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
168+
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
169+
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
170+
# replacing denominator with 1 when `eqc` is fine.
171+
channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
172+
rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)
173+
174+
mask_maxc_neq_r = maxc != r
175+
mask_maxc_eq_g = maxc == g
176+
mask_maxc_neq_g = ~mask_maxc_eq_g
177+
178+
hr = (bc - gc).mul_(~mask_maxc_neq_r)
179+
hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
180+
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
181+
182+
h = hr.add_(hg).add_(hb)
183+
h = h.div_(6.0).add_(1.0).fmod_(1.0)
184+
return torch.stack((h, s, maxc), dim=-3)
185+
186+
187+
def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
188+
h, s, v = img.unbind(dim=-3)
189+
h6 = h * 6
190+
i = torch.floor(h6)
191+
f = (h6) - i
192+
i = i.to(dtype=torch.int32)
193+
194+
p = (v * (1.0 - s)).clamp_(0.0, 1.0)
195+
q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
196+
t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
197+
i.remainder_(6)
198+
199+
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
200+
201+
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
202+
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
203+
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
204+
a4 = torch.stack((a1, a2, a3), dim=-4)
205+
206+
return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
207+
208+
209+
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
    """Shift the hue of an RGB image tensor by ``hue_factor``.

    Args:
        image: tensor of shape (..., C, H, W) with C in {1, 3}.
        hue_factor: hue shift in [-0.5, 0.5]; the H channel is shifted by this
            amount and wrapped around the unit circle.

    Returns:
        The hue-adjusted image in the input's dtype.

    Raises:
        ValueError: if ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: if ``image`` is not a tensor or has an unsupported channel count.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(image, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    num_channels = get_num_channels_image_tensor(image)
    if num_channels not in [1, 3]:
        raise TypeError(
            f"Input image tensor permitted channel values are {[1, 3]}, but found {num_channels}"
        )

    # Grayscale images are returned untouched to match PIL behaviour.
    if num_channels == 1:
        return image

    # Exit early on empty images: torch.aminmax inside _rgb_to_hsv fails on them.
    if image.numel() == 0:
        return image

    orig_dtype = image.dtype
    if orig_dtype == torch.uint8:
        # Work in float [0, 1]; converted back at the end.
        image = image / 255.0

    # Shift hue in HSV space, wrapping H around the unit interval.
    hsv = _rgb_to_hsv(image)
    h, s, v = hsv.unbind(dim=-3)
    h = torch.remainder(h + hue_factor, 1.0)
    image_hue_adj = _hsv_to_rgb(torch.stack((h, s, v), dim=-3))

    if orig_dtype == torch.uint8:
        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

    return image_hue_adj
147244
adjust_hue_image_pil = _FP.adjust_hue
148245

149246

0 commit comments

Comments
 (0)