Skip to content

Commit 3dd2e3d

Browse files
vfdev-5datumbox
andauthored
[proto] Speed improvement for autocontrast op (#6811)
* WIP * Updates to speed up autocontrast Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent e96860d commit 3dd2e3d

File tree

1 file changed

+28
-1
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+28
-1
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,34 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
211211
return solarize_image_pil(inpt, threshold=threshold)
212212

213213

214-
autocontrast_image_tensor = _FT.autocontrast
214+
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
215+
216+
if not (isinstance(image, torch.Tensor)):
217+
raise TypeError("Input img should be Tensor image")
218+
219+
c = get_num_channels_image_tensor(image)
220+
221+
if c not in [1, 3]:
222+
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
223+
224+
if image.numel() == 0:
225+
# exit earlier on empty images
226+
return image
227+
228+
bound = 1.0 if image.is_floating_point() else 255.0
229+
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
230+
231+
minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
232+
maximum = image.amax(dim=(-2, -1), keepdim=True).to(dtype)
233+
234+
scale = bound / (maximum - minimum)
235+
eq_idxs = maximum == minimum
236+
minimum[eq_idxs] = 0.0
237+
scale[eq_idxs] = 1.0
238+
239+
return (image - minimum).mul_(scale).clamp_(0, bound).to(image.dtype)
240+
241+
215242
autocontrast_image_pil = _FP.autocontrast
216243

217244

0 commit comments

Comments
 (0)