Commit 11a2eed

[proto] Small improvement for tensor equalize op (#6738)
* [proto] Small improvement for tensor equalize op
* Fix code formatting
* Added a comment on the ops
1 parent 9d16da2 commit 11a2eed
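
For context, here is a minimal usage sketch of the op this commit touches. It is an illustration only: the import path below is an assumption (the diff edits torchvision/prototype/transforms/functional/_color.py, and equalize_image_tensor is assumed to be re-exported by that package), and the shapes are arbitrary.

import torch

# Assumed import path; adjust if the prototype functional package exports
# the kernel differently.
from torchvision.prototype.transforms.functional import equalize_image_tensor

# The kernel only accepts torch.uint8 tensors (see the dtype check in the diff).
img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)  # C, H, W
out = equalize_image_tensor(img)
print(out.shape, out.dtype)  # torch.Size([3, 64, 64]) torch.uint8

# Batched inputs are reshaped to (-1, H, W), so every channel of every image
# is equalized independently and the original shape is restored afterwards.
batch = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)  # N, C, H, W
print(equalize_image_tensor(batch).shape)  # torch.Size([4, 3, 64, 64])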

File tree

1 file changed: +26 -8 lines changed

  • torchvision/prototype/transforms/functional/_color.py

torchvision/prototype/transforms/functional/_color.py

Lines changed: 26 additions & 8 deletions
@@ -183,6 +183,30 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


+def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
+    # TODO: we should expect bincount to always be faster than histc, but this
+    # isn't always the case. Once
+    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
+    # block and only use bincount.
+    if img_chan.is_cuda:
+        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
+    else:
+        hist = torch.bincount(img_chan.view(-1), minlength=256)
+
+    nonzero_hist = hist[hist != 0]
+    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
+    if step == 0:
+        return img_chan
+
+    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
+    # Doing inplace clamp and converting lut to uint8 improves perfs
+    lut.clamp_(0, 255)
+    lut = lut.to(torch.uint8)
+    lut = torch.nn.functional.pad(lut[:-1], [1, 0])
+
+    return lut[img_chan.to(torch.int64)]
+
+
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -194,15 +218,9 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
     elif image.ndim == 2:
-        return _FT._scale_channel(image)
+        return _scale_channel(image)
     else:
-        return torch.stack(
-            [
-                # TODO: when merging transforms v1 and v2, we can inline this function call
-                _FT._equalize_single_image(single_image)
-                for single_image in image.view(-1, num_channels, height, width)
-            ]
-        ).view(image.shape)
+        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)


 equalize_image_pil = _FP.equalize
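
The TODO in _scale_channel above points at pytorch/pytorch#53194: torch.bincount is expected to be faster than torch.histc but is not always, which is why the new code keeps torch.histc on CUDA and torch.bincount on CPU. Below is a rough micro-benchmark sketch of the two histogram calls used in the diff; the channel size, iteration count, and device handling are arbitrary choices for illustration, not part of the commit.

import torch
import torch.utils.benchmark as benchmark

# One 512x512 uint8 channel. Move it to CUDA (when available) to exercise the
# case the TODO is about; on CPU the commit uses the bincount path anyway.
chan = torch.randint(0, 256, (512, 512), dtype=torch.uint8)
if torch.cuda.is_available():
    chan = chan.cuda()

t_bincount = benchmark.Timer(
    stmt="torch.bincount(chan.view(-1), minlength=256)",
    globals={"torch": torch, "chan": chan},
)
t_histc = benchmark.Timer(
    stmt="torch.histc(chan.to(torch.float32), bins=256, min=0, max=255)",
    globals={"torch": torch, "chan": chan},
)
print(t_bincount.timeit(200))
print(t_histc.timeit(200))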
