File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed
torchvision/prototype/transforms/functional Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -312,7 +312,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
312312 return posterize_image_pil (inpt , bits = bits )
313313
314314
315- solarize_image_tensor = _FT .solarize
315+ def solarize_image_tensor (image : torch .Tensor , threshold : float ) -> torch .Tensor :
316+ if threshold > _FT ._max_value (image .dtype ):
317+ raise TypeError (f"Threshold should be less or equal the maximum value of the dtype, but got { threshold } " )
318+
319+ return torch .where (image >= threshold , invert_image_tensor (image ), image )
320+
321+
316322solarize_image_pil = _FP .solarize
317323
318324
@@ -456,7 +462,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
456462 return equalize_image_pil (inpt )
457463
458464
459- invert_image_tensor = _FT .invert
465+ def invert_image_tensor (image : torch .Tensor ) -> torch .Tensor :
466+ if image .dtype == torch .uint8 :
467+ return image .bitwise_not ()
468+ else :
469+ return _FT ._max_value (image .dtype ) - image # type: ignore[no-any-return]
470+
471+
460472invert_image_pil = _FP .invert
461473
462474
You can’t perform that action at this time.
0 commit comments