@@ -183,6 +183,30 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
        return autocontrast_image_pil(inpt)


+def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
+    # TODO: we should expect bincount to always be faster than histc, but this
+    # isn't always the case. Once
+    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
+    # block and only use bincount.
+    if img_chan.is_cuda:
+        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
+    else:
+        hist = torch.bincount(img_chan.view(-1), minlength=256)
+
+    nonzero_hist = hist[hist != 0]
+    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
+    if step == 0:
+        return img_chan
+
+    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
+    # Clamping in place and converting lut to uint8 improves performance
+    lut.clamp_(0, 255)
+    lut = lut.to(torch.uint8)
+    lut = torch.nn.functional.pad(lut[:-1], [1, 0])
+
+    return lut[img_chan.to(torch.int64)]
+
+
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -194,15 +218,9 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.numel() == 0:
        return image
    elif image.ndim == 2:
-        return _FT._scale_channel(image)
+        return _scale_channel(image)
    else:
-        return torch.stack(
-            [
-                # TODO: when merging transforms v1 and v2, we can inline this function call
-                _FT._equalize_single_image(single_image)
-                for single_image in image.view(-1, num_channels, height, width)
-            ]
-        ).view(image.shape)
+        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)


equalize_image_pil = _FP.equalize
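A hypothetical usage sketch of the batched path after this change; the import path is an assumption about where this module sits in the prototype transforms tree, and the shapes are made up.

```python
import torch

# Assumed import path for the prototype v2 functional API; adjust to
# wherever equalize_image_tensor is exposed in your checkout.
from torchvision.prototype.transforms.functional import equalize_image_tensor

# Batched uint8 input: the else-branch views (N, C, H, W) as N*C separate
# (H, W) planes and equalizes each channel independently.
batch = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8)
out = equalize_image_tensor(batch)
assert out.shape == batch.shape and out.dtype == torch.uint8
```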