Skip to content

Commit 7f5513d

Browse files
authored
improve performance of {invert, solarize}_image_tensor (#6819)
* improve performance of invert_image_tensor * cleanup * lint * more cleanup * use new invert in solarize
1 parent 6979888 commit 7f5513d

File tree

1 file changed

+14
-2
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+14
-2
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
312312
return posterize_image_pil(inpt, bits=bits)
313313

314314

315-
solarize_image_tensor = _FT.solarize
315+
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
316+
if threshold > _FT._max_value(image.dtype):
317+
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
318+
319+
return torch.where(image >= threshold, invert_image_tensor(image), image)
320+
321+
316322
solarize_image_pil = _FP.solarize
317323

318324

@@ -456,7 +462,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
456462
return equalize_image_pil(inpt)
457463

458464

459-
invert_image_tensor = _FT.invert
465+
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
466+
if image.dtype == torch.uint8:
467+
return image.bitwise_not()
468+
else:
469+
return _FT._max_value(image.dtype) - image # type: ignore[no-any-return]
470+
471+
460472
invert_image_pil = _FP.invert
461473

462474

0 commit comments

Comments
 (0)