From 794bbb11788540e449f8c0faa2646cfcf6bfa023 Mon Sep 17 00:00:00 2001 From: Brian Date: Fri, 17 Jul 2020 14:05:00 -0400 Subject: [PATCH 1/7] make convert_image_dtype scriptable --- test/test_transforms.py | 4 +++ torchvision/transforms/functional.py | 37 +++++++++++++++------ torchvision/transforms/functional_tensor.py | 8 ++--- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 61ec525961d..ec79d8c63cc 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -526,6 +526,10 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + def test_max_value(self): + for dtype in int_dtypes(): + self.assertEqual(F._max_value(dtype), torch.iinfo(dtype).max) + def test_convert_image_dtype_float_to_float(self): for input_dtype, output_dtypes in cycle_over(float_dtypes()): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index f3d1f96089f..5cd601bcb73 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -124,7 +124,24 @@ def pil_to_tensor(pic): return img -def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: +# torch.iinfo isn't scriptable so using this helper function +# https://github.com/pytorch/pytorch/issues/41492 +def _max_value(dtype: int) -> int: + a = torch.tensor(2, dtype=dtype) + signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0 + bits = 1 + max_value = torch.tensor(-signed, dtype=torch.long) + while(True): + next_value = a.pow(bits - signed).sub(1) + if next_value > max_value: + max_value = next_value + bits *= 2 + else: + return max_value.item() + return max_value.item() + + +def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: @@ -148,9 +165,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - if image.dtype == dtype: return image - if image.dtype.is_floating_point: + if torch.empty(0, dtype=image.dtype).is_floating_point(): # float to float - if dtype.is_floating_point: + if torch.tensor(0, dtype=dtype).is_floating_point(): return image.to(dtype) # float to int @@ -166,19 +183,19 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - # `max + 1 - epsilon` provides more evenly distributed mapping of # ranges of floats to ints. 
eps = 1e-3 - result = image.mul(torch.iinfo(dtype).max + 1 - eps) + max_val = _max_value(dtype) + result = image.mul(max_val + 1.0 - eps) return result.to(dtype) else: + input_max = _max_value(image.dtype) + output_max = _max_value(dtype) + # int to float - if dtype.is_floating_point: - max = torch.iinfo(image.dtype).max + if torch.tensor(0, dtype=dtype).is_floating_point(): image = image.to(dtype) - return image / max + return image / input_max # int to int - input_max = torch.iinfo(image.dtype).max - output_max = torch.iinfo(dtype).max - if input_max > output_max: factor = (input_max + 1) // (output_max + 1) image = image // factor diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index f2e47b056d3..4627e641eed 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -6,6 +6,8 @@ from torch.nn.functional import affine_grid, grid_sample from torch.jit.annotations import List, BroadcastingList2 +import torchvision.transforms.functional as F + def _is_tensor_a_torch_image(x: Tensor) -> bool: return x.ndim >= 2 @@ -228,13 +230,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: result = img dtype = img.dtype if not torch.is_floating_point(img): - result = result / 255.0 + result = F.convert_image_dtype(result, torch.get_default_dtype()) result = (gain * result ** gamma).clamp(0, 1) - if result.dtype != dtype: - eps = 1e-3 - result = (255 + 1.0 - eps) * result + result = F.convert_image_dtype(result, dtype) result = result.to(dtype) return result From f610463f1aa886de26dd1b8fe8b95d3592694875 Mon Sep 17 00:00:00 2001 From: Brian Date: Fri, 17 Jul 2020 14:34:53 -0400 Subject: [PATCH 2/7] move convert dtype to functional_tensor since only works on tensors --- test/test_functional_tensor.py | 4 +- test/test_transforms.py | 3 +- torchvision/transforms/functional.py | 82 ------------------- torchvision/transforms/functional_tensor.py | 88 ++++++++++++++++++++- torchvision/transforms/transforms.py | 4 +- 5 files changed, 90 insertions(+), 91 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 2e3477ad12b..847ba14356c 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -309,7 +309,7 @@ def test_adjust_gamma(self): for dt in [torch.float64, torch.float32, None]: if dt is not None: - tensor = F.convert_image_dtype(tensor, dt) + tensor = F_t.convert_image_dtype(tensor, dt) gammas = [0.8, 1.0, 1.2] gains = [0.7, 1.0, 1.3] @@ -323,7 +323,7 @@ def test_adjust_gamma(self): rbg_tensor = adjusted_tensor if adjusted_tensor.dtype != torch.uint8: - rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8) + rbg_tensor = F_t.convert_image_dtype(adjusted_tensor, torch.uint8) self.compareTensorToPIL(rbg_tensor, adjusted_pil) diff --git a/test/test_transforms.py b/test/test_transforms.py index ec79d8c63cc..ac13de8d1f6 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,6 +2,7 @@ import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F +import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 from numpy.testing import assert_array_almost_equal import unittest @@ -528,7 +529,7 @@ def test_to_tensor(self): def test_max_value(self): for dtype in int_dtypes(): - self.assertEqual(F._max_value(dtype), torch.iinfo(dtype).max) + self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max) def 
test_convert_image_dtype_float_to_float(self): for input_dtype, output_dtypes in cycle_over(float_dtypes()): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5cd601bcb73..8bf589fbaf8 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -124,88 +124,6 @@ def pil_to_tensor(pic): return img -# torch.iinfo isn't scriptable so using this helper function -# https://github.com/pytorch/pytorch/issues/41492 -def _max_value(dtype: int) -> int: - a = torch.tensor(2, dtype=dtype) - signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0 - bits = 1 - max_value = torch.tensor(-signed, dtype=torch.long) - while(True): - next_value = a.pow(bits - signed).sub(1) - if next_value > max_value: - max_value = next_value - bits *= 2 - else: - return max_value.item() - return max_value.item() - - -def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: - """Convert a tensor image to the given ``dtype`` and scale the values accordingly - - Args: - image (torch.Tensor): Image to be converted - dtype (torch.dtype): Desired data type of the output - - Returns: - (torch.Tensor): Converted image - - .. note:: - - When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. - If converted back and forth, this mismatch has no effect. - - Raises: - RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as - well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range - of the integer ``dtype``. - """ - if image.dtype == dtype: - return image - - if torch.empty(0, dtype=image.dtype).is_floating_point(): - # float to float - if torch.tensor(0, dtype=dtype).is_floating_point(): - return image.to(dtype) - - # float to int - if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( - image.dtype == torch.float64 and dtype == torch.int64 - ): - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." - raise RuntimeError(msg) - - # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 - # For data in the range 0-1, (float * 255).to(uint) is only 255 - # when float is exactly 1.0. - # `max + 1 - epsilon` provides more evenly distributed mapping of - # ranges of floats to ints. - eps = 1e-3 - max_val = _max_value(dtype) - result = image.mul(max_val + 1.0 - eps) - return result.to(dtype) - else: - input_max = _max_value(image.dtype) - output_max = _max_value(dtype) - - # int to float - if torch.tensor(0, dtype=dtype).is_floating_point(): - image = image.to(dtype) - return image / input_max - - # int to int - if input_max > output_max: - factor = (input_max + 1) // (output_max + 1) - image = image // factor - return image.to(dtype) - else: - factor = (output_max + 1) // (input_max + 1) - image = image.to(dtype) - return image * factor - - def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. 
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4627e641eed..1c99ea0b55f 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -6,8 +6,6 @@ from torch.nn.functional import affine_grid, grid_sample from torch.jit.annotations import List, BroadcastingList2 -import torchvision.transforms.functional as F - def _is_tensor_a_torch_image(x: Tensor) -> bool: return x.ndim >= 2 @@ -20,6 +18,88 @@ def _get_image_size(img: Tensor) -> List[int]: raise TypeError("Unexpected type {}".format(type(img))) +# torch.iinfo isn't scriptable so using this helper function +# https://github.com/pytorch/pytorch/issues/41492 +def _max_value(dtype: int) -> int: + a = torch.tensor(2, dtype=dtype) + signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0 + bits = 1 + max_value = torch.tensor(-signed, dtype=torch.long) + while(True): + next_value = a.pow(bits - signed).sub(1) + if next_value > max_value: + max_value = next_value + bits *= 2 + else: + return max_value.item() + return max_value.item() + + +def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + (torch.Tensor): Converted image + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + if image.dtype == dtype: + return image + + if torch.empty(0, dtype=image.dtype).is_floating_point(): + # float to float + if torch.tensor(0, dtype=dtype).is_floating_point(): + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) + + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. + eps = 1e-3 + max_val = _max_value(dtype) + result = image.mul(max_val + 1.0 - eps) + return result.to(dtype) + else: + input_max = _max_value(image.dtype) + output_max = _max_value(dtype) + + # int to float + if torch.tensor(0, dtype=dtype).is_floating_point(): + image = image.to(dtype) + return image / input_max + + # int to int + if input_max > output_max: + factor = (input_max + 1) // (output_max + 1) + image = image // factor + return image.to(dtype) + else: + factor = (output_max + 1) // (input_max + 1) + image = image.to(dtype) + return image * factor + + def vflip(img: Tensor) -> Tensor: """Vertically flip the given the Image Tensor. 
@@ -230,11 +310,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: result = img dtype = img.dtype if not torch.is_floating_point(img): - result = F.convert_image_dtype(result, torch.get_default_dtype()) + result = convert_image_dtype(result, torch.float32) result = (gain * result ** gamma).clamp(0, 1) - result = F.convert_image_dtype(result, dtype) + result = convert_image_dtype(result, dtype) result = result.to(dtype) return result diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f7d421d2b83..540ca43776d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -16,7 +16,7 @@ accimage = None from . import functional as F - +from . import functional_tensor as F_t __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", @@ -131,7 +131,7 @@ def __init__(self, dtype: torch.dtype) -> None: self.dtype = dtype def __call__(self, image: torch.Tensor) -> torch.Tensor: - return F.convert_image_dtype(image, self.dtype) + return F_t.convert_image_dtype(image, self.dtype) class ToPILImage(object): From ace0f93ba0d66dff1914b4750a62a1561fbe7e37 Mon Sep 17 00:00:00 2001 From: Brian Date: Fri, 7 Aug 2020 09:53:58 -0400 Subject: [PATCH 3/7] retain availability of convert_image_dtype in functional.py --- torchvision/transforms/functional.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8bf589fbaf8..c4fff4729c3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -124,6 +124,12 @@ def pil_to_tensor(pic): return img +# import to main namespace +# this is temporary until we merge the implementation of +# F_t inside functional.py +convert_image_dtype = F_t.convert_image_dtype + + def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. 
From 242a0b2af984eb8fa5fa8c43465dc0b9ab1768f6 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 2 Oct 2020 16:37:06 +0000 Subject: [PATCH 4/7] Update code and tests --- test/test_functional_tensor.py | 4 +++ test/test_transforms.py | 28 ++++++++++++++++++++ torchvision/transforms/functional.py | 29 ++++++++++++++++++--- torchvision/transforms/functional_tensor.py | 25 +++++++++++++----- torchvision/transforms/transforms.py | 5 ++-- 5 files changed, 78 insertions(+), 13 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 87373359e83..855754ea68d 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -744,6 +744,10 @@ def test_perspective(self): batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 ) + def test_convert_image_dtype(self): + # TODO: add tests of CPU/CUDA on tensor and batch + pass + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index d1cf34cb351..fb11fd2aa05 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -539,13 +539,22 @@ def test_max_value(self): for dtype in int_dtypes(): self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max) + for dtype in float_dtypes(): + self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max) + def test_convert_image_dtype_float_to_float(self): for input_dtype, output_dtypes in cycle_over(float_dtypes()): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -559,6 +568,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -567,6 +577,10 @@ def test_convert_image_dtype_float_to_int(self): transform(input_image) else: output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, torch.iinfo(output_dtype).max @@ -580,7 +594,13 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertLess(script_diff.abs().max(), 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -599,7 
+619,15 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script.float() - output_image.float() + self.assertLess( + script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image) + ) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, output_max diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a6e899eeba7..07eeb796ffb 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -131,10 +131,31 @@ def pil_to_tensor(pic): return img -# import to main namespace -# this is temporary until we merge the implementation of -# F_t inside functional.py -convert_image_dtype = F_t.convert_image_dtype +def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + (torch.Tensor): Converted image + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + if not isinstance(image, torch.Tensor): + raise TypeError('Input img should be Tensor Image') + + return F_t.convert_image_dtype(image, dtype) def to_pil_image(pic, mode=None): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d96c7b9b6d4..d515735eac1 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -27,14 +27,15 @@ def _get_image_num_channels(img: Tensor) -> int: raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) -# torch.iinfo isn't scriptable so using this helper function -# https://github.com/pytorch/pytorch/issues/41492 -def _max_value(dtype: int) -> int: +def _max_value(dtype: int) -> float: + # TODO: replace this method with torch.iinfo when it gets torchscript support. + # https://github.com/pytorch/pytorch/issues/41492 + a = torch.tensor(2, dtype=dtype) signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0 bits = 1 max_value = torch.tensor(-signed, dtype=torch.long) - while(True): + while True: next_value = a.pow(bits - signed).sub(1) if next_value > max_value: max_value = next_value @@ -45,7 +46,12 @@ def _max_value(dtype: int) -> int: def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: - """Convert a tensor image to the given ``dtype`` and scale the values accordingly + """PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly + + .. 
warning:: + + Module ``transforms.functional_tensor`` is private and should not be used in user application. + Please, consider instead using methods from `transforms.functional` module. Args: image (torch.Tensor): Image to be converted @@ -68,8 +74,10 @@ def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch. if image.dtype == dtype: return image + # TODO: replace with image.dtype.is_floating_point when torchscript supports it if torch.empty(0, dtype=image.dtype).is_floating_point(): - # float to float + + # TODO: replace with dtype.is_floating_point when torchscript supports it if torch.tensor(0, dtype=dtype).is_floating_point(): return image.to(dtype) @@ -94,13 +102,16 @@ def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch. output_max = _max_value(dtype) # int to float + # TODO: replace with dtype.is_floating_point when torchscript supports it if torch.tensor(0, dtype=dtype).is_floating_point(): image = image.to(dtype) return image / input_max # int to int if input_max > output_max: - factor = (input_max + 1) // (output_max + 1) + # factor should be forced to int for torch jit script + # otherwise factor is a float and image // factor can produce different results + factor = int((input_max + 1) // (output_max + 1)) image = image // factor return image.to(dtype) else: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 651d4173688..a3efeab4d2d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -127,7 +127,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class ConvertImageDtype: +class ConvertImageDtype(torch.nn.Module): """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: @@ -146,9 +146,10 @@ class ConvertImageDtype: """ def __init__(self, dtype: torch.dtype) -> None: + super().__init__() self.dtype = dtype - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: return F_t.convert_image_dtype(image, self.dtype) From b098a1453216f037042645f65a58b32f4000bb5c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 2 Oct 2020 17:04:25 +0000 Subject: [PATCH 5/7] Replaced int by torch.dtype --- torchvision/transforms/functional.py | 2 +- torchvision/transforms/functional_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 07eeb796ffb..9085b0c45e8 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -131,7 +131,7 @@ def pil_to_tensor(pic): return img -def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d515735eac1..c9a7ab57dc7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -45,7 +45,7 @@ def _max_value(dtype: int) -> float: return max_value.item() -def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor: +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """PRIVATE METHOD. 
Convert a tensor image to the given ``dtype`` and scale the values accordingly .. warning:: From e5781479c46803d18341e176447ecc1e08955447 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 2 Oct 2020 17:14:01 +0000 Subject: [PATCH 6/7] int -> torch.dtype and use F instead of F_t --- torchvision/transforms/functional_tensor.py | 2 +- torchvision/transforms/transforms.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index c9a7ab57dc7..e1d756bd8c4 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -27,7 +27,7 @@ def _get_image_num_channels(img: Tensor) -> int: raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) -def _max_value(dtype: int) -> float: +def _max_value(dtype: torch.dtype) -> float: # TODO: replace this method with torch.iinfo when it gets torchscript support. # https://github.com/pytorch/pytorch/issues/41492 diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a3efeab4d2d..2a585f98c3f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,7 +15,6 @@ accimage = None from . import functional as F -from . import functional_tensor as F_t __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", @@ -150,7 +149,7 @@ def __init__(self, dtype: torch.dtype) -> None: self.dtype = dtype def forward(self, image: torch.Tensor) -> torch.Tensor: - return F_t.convert_image_dtype(image, self.dtype) + return F.convert_image_dtype(image, self.dtype) class ToPILImage: From 0790f3044e29a5b7c8b46c5ffcb69ba684a8068e Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 5 Oct 2020 12:54:20 +0200 Subject: [PATCH 7/7] Update functional_tensor.py --- torchvision/transforms/functional_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index e1d756bd8c4..5436aeff9c0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -115,7 +115,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - image = image // factor return image.to(dtype) else: - factor = (output_max + 1) // (input_max + 1) + # factor should be forced to int for torch jit script + # otherwise factor is a float and image * factor can produce different results + factor = int((output_max + 1) // (input_max + 1)) image = image.to(dtype) return image * factor
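
Note (added for illustration): the series above makes convert_image_dtype TorchScript-compatible by replacing torch.iinfo with the _max_value helper and by keeping the public entry point in transforms.functional as a thin wrapper over the tensor implementation. The sketch below is a minimal usage example, assuming a torchvision build that already contains these patches; the variable names (img_u8, convert_scripted) are illustrative and not part of the changeset. It mirrors what the new tests in test_transforms.py check: the scripted functional should agree with the eager ConvertImageDtype transform.

    import torch
    import torchvision.transforms as transforms
    import torchvision.transforms.functional as F

    # Scripting succeeds because convert_image_dtype no longer calls torch.iinfo,
    # which TorchScript cannot compile (pytorch/pytorch#41492); it relies on the
    # _max_value helper instead.
    convert_scripted = torch.jit.script(F.convert_image_dtype)

    # uint8 -> float32: values are rescaled from [0, 255] to [0.0, 1.0]
    img_u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
    out_eager = transforms.ConvertImageDtype(torch.float32)(img_u8)
    out_script = convert_scripted(img_u8, torch.float32)
    assert (out_eager - out_script).abs().max() < 1e-6

    # float32 -> uint8: values in [0.0, 1.0] are spread over [0, 255] using the
    # `max + 1 - eps` multiplier described in the diff comments.
    img_f32 = torch.rand(3, 4, 4)
    back_u8 = convert_scripted(img_f32, torch.uint8)
    print(back_u8.dtype, int(back_u8.min()), int(back_u8.max()))

As a design note, _max_value recovers torch.iinfo(dtype).max without calling iinfo: it repeatedly doubles the candidate bit width and evaluates 2**(bits - signed) - 1 in the target dtype, and once that value stops increasing (i.e. the power overflows the dtype) it returns the largest value seen, which is the dtype's maximum.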