from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
    ratio = float(ratio)
    fp = image1.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image1.dtype)
    output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
    return output if fp else output.to(image1.dtype)
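# Illustrative sketch (an assumption, not torchvision's actual helper): `_FT._max_value`
# is expected to return the largest representable value for the image dtype, so the
# clamping bound no longer hard-codes the uint8 range. A rough stand-in could look like:
import torch

def _approx_max_value(dtype: torch.dtype) -> float:
    # Float images are assumed to live in [0, 1]; integer images use their full dtype range.
    return 1.0 if dtype.is_floating_point else float(torch.iinfo(dtype).max)

# e.g. 255.0 for torch.uint8, 32767.0 for torch.int16, 1.0 for torch.float32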
@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
    _FT._assert_channels(image, [1, 3])

    fp = image.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image.dtype)
    output = image.mul(brightness_factor).clamp_(0, bound)
    return output if fp else output.to(image.dtype)
@@ -222,19 +222,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
        return image

    orig_dtype = image.dtype
-    if image.dtype == torch.uint8:
-        image = image / 255.0
+    image = convert_dtype_image_tensor(image, torch.float32)

    image = _rgb_to_hsv(image)
    h, s, v = image.unbind(dim=-3)
    h.add_(hue_factor).remainder_(1.0)
    image = torch.stack((h, s, v), dim=-3)
    image_hue_adj = _hsv_to_rgb(image)

-    if orig_dtype == torch.uint8:
-        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
-
-    return image_hue_adj
+    return convert_dtype_image_tensor(image_hue_adj, orig_dtype)


adjust_hue_image_pil = _FP.adjust_hue
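# Usage sketch of the change above: the hue adjustment now round-trips through float32 via
# convert_dtype_image_tensor instead of a hand-rolled `/ 255.0` that only covered uint8.
# (Assumes adjust_hue_image_tensor is called directly from this module.)
import torch

img_u8 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
out = adjust_hue_image_tensor(img_u8, hue_factor=0.25)
assert out.dtype == torch.uint8  # the original dtype is restored on the way out
# Other dtypes (e.g. torch.float32 or, presumably, other integer dtypes) follow the same
# path now, since the scaling is no longer tied to the value 255.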
@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->


def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
-    if bits > 8:
-        return image
-
    if image.is_floating_point():
        levels = 1 << bits
        return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
    else:
-        mask = ((1 << bits) - 1) << (8 - bits)
+        num_value_bits = _num_value_bits(image.dtype)
+        if bits >= num_value_bits:
+            return image
+
+        mask = ((1 << bits) - 1) << (num_value_bits - bits)
        return image & mask
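# Worked example for the new mask (assuming _num_value_bits returns the number of
# non-sign bits of the dtype, e.g. 8 for torch.uint8 and 15 for torch.int16):
#   uint8, bits=3: ((1 << 3) - 1) << (8 - 3)  == 0b11100000 == 224
#   int16, bits=3: ((1 << 3) - 1) << (15 - 3) == 7 << 12    == 28672
# For bits >= num_value_bits there is nothing to truncate, so the image is returned
# unchanged; the old code could only short-circuit the uint8 case via `bits > 8`.
assert ((1 << 3) - 1) << (8 - 3) == 224
assert ((1 << 3) - 1) << (15 - 3) == 28672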
@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:


def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    bound = 1 if image.is_floating_point() else 255
-    if threshold > bound:
+    if threshold > _FT._max_value(image.dtype):
        raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

    return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
        # exit earlier on empty images
        return image

-    bound = 1.0 if image.is_floating_point() else 255.0
+    bound = _FT._max_value(image.dtype)
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
@@ -383,14 +379,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.numel() == 0:
        return image

+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    #    Since we need to convert in most cases anyway, and since out of the acceptable dtypes mentioned in 1.
+    #    `torch.uint8` is by far the most common, we choose it as the base.
    output_dtype = image.dtype
-    if image.is_floating_point():
-        # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
-        # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
-        # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
-        # slower and more complicated to implement than a simple conversion and a fast histogram implementation for
-        # integers.
-        image = convert_dtype_image_tensor(image, torch.uint8)
+    image = convert_dtype_image_tensor(image, torch.uint8)

    # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
    # corresponds to adding 1 to index 127 in the histogram.
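# Minimal sketch of the index-based histogram described in the comment above (illustrative
# only, not the actual implementation): each uint8 pixel value v bumps bin v by one.
import torch

flat = torch.randint(0, 256, (32 * 32,), dtype=torch.uint8)
hist = torch.zeros(256, dtype=torch.int64)
hist.index_add_(0, flat.to(torch.int64), torch.ones_like(flat, dtype=torch.int64))
assert int(hist.sum()) == flat.numel()
# With 256 fixed bins this stays cheap and batchable, which is why converting everything
# to torch.uint8 first (as done above) keeps the algorithm simple.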
@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:


def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype == torch.uint8:
+    if image.is_floating_point():
+        return 1.0 - image  # type: ignore[no-any-return]
+    elif image.dtype == torch.uint8:
        return image.bitwise_not()
-    else:
-        return (1 if image.is_floating_point() else 255) - image  # type: ignore[no-any-return]
+    else:  # signed integer dtypes
+        # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
+        return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)


invert_image_pil = _FP.invert
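# Why XOR with a value-bit mask instead of bitwise_not for signed dtypes: bitwise_not
# also flips the sign bit, pushing results out of the valid image range. Assuming
# _num_value_bits(torch.int8) == 7, the mask is (1 << 7) - 1 == 0b0111_1111 == 127:
import torch

x = torch.tensor([0, 3, 127], dtype=torch.int8)
assert x.bitwise_xor(127).tolist() == [127, 124, 0]   # inverted within the value bits
assert x.bitwise_not().tolist() == [-1, -4, -128]     # sign bit flipped as well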