@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
143
143
return adjust_sharpness_image_pil (inpt , sharpness_factor = sharpness_factor )
144
144
145
145
146
- adjust_hue_image_tensor = _FT .adjust_hue
146
+ def _rgb_to_hsv (image : torch .Tensor ) -> torch .Tensor :
147
+ r , g , _ = image .unbind (dim = - 3 )
148
+
149
+ # Implementation is based on
150
+ # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
151
+ minc , maxc = torch .aminmax (image , dim = - 3 )
152
+
153
+ # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
154
+ # from happening in the results, because
155
+ # + S channel has division by `maxc`, which is zero only if `maxc = minc`
156
+ # + H channel has division by `(maxc - minc)`.
157
+ #
158
+ # Instead of overwriting NaN afterwards, we just prevent it from occuring so
159
+ # we don't need to deal with it in case we save the NaN in a buffer in
160
+ # backprop, if it is ever supported, but it doesn't hurt to do so.
161
+ eqc = maxc == minc
162
+
163
+ channels_range = maxc - minc
164
+ # Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine.
165
+ ones = torch .ones_like (maxc )
166
+ s = channels_range / torch .where (eqc , ones , maxc )
167
+ # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
168
+ # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
169
+ # would not matter what values `rc`, `gc`, and `bc` have here, and thus
170
+ # replacing denominator with 1 when `eqc` is fine.
171
+ channels_range_divisor = torch .where (eqc , ones , channels_range ).unsqueeze_ (dim = - 3 )
172
+ rc , gc , bc = ((maxc .unsqueeze (dim = - 3 ) - image ) / channels_range_divisor ).unbind (dim = - 3 )
173
+
174
+ mask_maxc_neq_r = maxc != r
175
+ mask_maxc_eq_g = maxc == g
176
+ mask_maxc_neq_g = ~ mask_maxc_eq_g
177
+
178
+ hr = (bc - gc ).mul_ (~ mask_maxc_neq_r )
179
+ hg = (2.0 + rc ).sub_ (bc ).mul_ (mask_maxc_eq_g & mask_maxc_neq_r )
180
+ hb = (4.0 + gc ).sub_ (rc ).mul_ (mask_maxc_neq_g & mask_maxc_neq_r )
181
+
182
+ h = hr .add_ (hg ).add_ (hb )
183
+ h = h .div_ (6.0 ).add_ (1.0 ).fmod_ (1.0 )
184
+ return torch .stack ((h , s , maxc ), dim = - 3 )
185
+
186
+
187
+ def _hsv_to_rgb (img : torch .Tensor ) -> torch .Tensor :
188
+ h , s , v = img .unbind (dim = - 3 )
189
+ h6 = h * 6
190
+ i = torch .floor (h6 )
191
+ f = (h6 ) - i
192
+ i = i .to (dtype = torch .int32 )
193
+
194
+ p = (v * (1.0 - s )).clamp_ (0.0 , 1.0 )
195
+ q = (v * (1.0 - s * f )).clamp_ (0.0 , 1.0 )
196
+ t = (v * (1.0 - s * (1.0 - f ))).clamp_ (0.0 , 1.0 )
197
+ i .remainder_ (6 )
198
+
199
+ mask = i .unsqueeze (dim = - 3 ) == torch .arange (6 , device = i .device ).view (- 1 , 1 , 1 )
200
+
201
+ a1 = torch .stack ((v , q , p , p , t , v ), dim = - 3 )
202
+ a2 = torch .stack ((t , v , v , q , p , p ), dim = - 3 )
203
+ a3 = torch .stack ((p , p , t , v , v , q ), dim = - 3 )
204
+ a4 = torch .stack ((a1 , a2 , a3 ), dim = - 4 )
205
+
206
+ return (a4 .mul_ (mask .to (dtype = img .dtype ).unsqueeze (dim = - 4 ))).sum (dim = - 3 )
207
+
208
+
209
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
    """Shift the hue of an RGB image tensor by `hue_factor` turns.

    Args:
        image: tensor with channels on dim -3; 1- or 3-channel images only.
        hue_factor: hue rotation in [-0.5, 0.5] (fraction of the full circle).

    Returns:
        Tensor of the same shape and dtype with the hue channel rotated.

    Raises:
        ValueError: if `hue_factor` is outside [-0.5, 0.5].
        TypeError: if `image` is not a tensor or has an unsupported channel count.
    """
    if not -0.5 <= hue_factor <= 0.5:
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor image")

    c = get_num_channels_image_tensor(image)
    if c not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

    # Single-channel images pass through untouched (matches PIL); empty
    # tensors exit early so the HSV round-trip never sees zero elements.
    if c == 1 or image.numel() == 0:
        return image

    # uint8 input is normalized to [0, 1] floats for the HSV math and scaled
    # back (with the original dtype restored) at the end.
    orig_dtype = image.dtype
    was_uint8 = orig_dtype == torch.uint8
    if was_uint8:
        image = image / 255.0

    h, s, v = _rgb_to_hsv(image).unbind(dim=-3)
    # Rotate hue and wrap into [0, 1).
    h.add_(hue_factor).remainder_(1.0)
    image_hue_adj = _hsv_to_rgb(torch.stack((h, s, v), dim=-3))

    if was_uint8:
        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

    return image_hue_adj
242
+
243
+
147
244
# PIL path: delegate hue adjustment of PIL images to the existing helper.
# NOTE(review): presumably `_FP` is torchvision's functional_pil module,
# imported earlier in this file — confirm against the file header.
adjust_hue_image_pil = _FP.adjust_hue
148
245
149
246
0 commit comments