From 1744362b9eb2f07801cabe3f450fbfd88cf34cbd Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 4 May 2023 16:41:21 +0200
Subject: [PATCH 1/7] Resize relies on interpolate's native uint8 handling

Description:

- Now that https://github.com/pytorch/pytorch/pull/90771 is merged, let Resize()
  rely on interpolate()'s native uint8 handling instead of converting to and
  from float.
- uint8 input is not cast to f32 for nearest mode, nor for bilinear mode when
  the CPU supports AVX2.

Context: https://github.com/pytorch/vision/issues/7217

Benchmarks:

```
[----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------]
                        |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)     |     457     |       461       |        480
      (16, 3, 400, 400) |    6870     |      6850       |      10100

Times are in microseconds (us).

[---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------]
                        |  resize v2  |  resize stable  |  resize nightly
1 threads: ---------------------------------------------------------------
      (3, 400, 400)     |     326     |       329       |        844
      (16, 3, 400, 400) |    4380     |      4390       |      14800

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2)
---
 torchvision/transforms/_functional_tensor.py      | 12 +++++++++++-
 torchvision/transforms/v2/functional/_geometry.py | 12 +++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index d0e7c17882b..117ba488ce2 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -459,7 +459,17 @@ def resize(
         # now we don't as True is the default.
         antialias = False
 
-    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
+    acceptable_dtypes = [torch.float32, torch.float64]
+    if interpolation in ["nearest", "nearest-exact"]:
+        # uint8 dtype can be included for cpu and cuda input if nearest mode
+        acceptable_dtypes.append(torch.uint8)
+    elif interpolation == "bilinear" and img.device.type == "cpu":
+        # uint8 dtype support for bilinear mode is limited to cpu and
+        # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
+        if "AVX2" in torch.backends.cpu.get_cpu_capability():
+            acceptable_dtypes.append(torch.uint8)
+
+    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 1ceabbd80f0..629422f75cf 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -185,7 +185,17 @@ def resize_image_tensor(
         image = image.reshape(-1, num_channels, old_height, old_width)
 
         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation in [InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT]:
+            # uint8 dtype can be included for cpu and cuda input if nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and
+            # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

From 6580907b036c67fc5a4ecb510d883f8cd2101214 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 4 May 2023 19:31:12 +0200
Subject: [PATCH 2/7] Skip torchscript when checking cpu capability as
 torch.backends.cpu.get_cpu_capability() can't be scripted

---
 torchvision/transforms/_functional_tensor.py      | 6 ++++--
 torchvision/transforms/v2/functional/_geometry.py | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 117ba488ce2..307a52a16a2 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -466,8 +466,10 @@ def resize(
     elif interpolation == "bilinear" and img.device.type == "cpu":
         # uint8 dtype support for bilinear mode is limited to cpu and
         # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-        if "AVX2" in torch.backends.cpu.get_cpu_capability():
-            acceptable_dtypes.append(torch.uint8)
+        # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
+        if not torch.jit.is_scripting():
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
 
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 629422f75cf..06c783f6bb9 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -192,8 +192,10 @@ def resize_image_tensor(
         elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
             # uint8 dtype support for bilinear mode is limited to cpu and
             # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-            if "AVX2" in torch.backends.cpu.get_cpu_capability():
-                acceptable_dtypes.append(torch.uint8)
+            # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
+            if not torch.jit.is_scripting():
+                if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                    acceptable_dtypes.append(torch.uint8)
 
         need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

From 791488b833c7bc9d2a9e36e0397a544263776b88 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 9 May 2023 13:09:54 +0200
Subject: [PATCH 3/7] Reverted changes in v1 and updated v2 code

---
 torchvision/transforms/_functional_tensor.py      | 14 +-------------
 torchvision/transforms/v2/functional/_geometry.py |  7 +++----
 2 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 307a52a16a2..d0e7c17882b 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -459,19 +459,7 @@ def resize(
         # now we don't as True is the default.
         antialias = False
 
-    acceptable_dtypes = [torch.float32, torch.float64]
-    if interpolation in ["nearest", "nearest-exact"]:
-        # uint8 dtype can be included for cpu and cuda input if nearest mode
-        acceptable_dtypes.append(torch.uint8)
-    elif interpolation == "bilinear" and img.device.type == "cpu":
-        # uint8 dtype support for bilinear mode is limited to cpu and
-        # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-        # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
-        if not torch.jit.is_scripting():
-            if "AVX2" in torch.backends.cpu.get_cpu_capability():
-                acceptable_dtypes.append(torch.uint8)
-
-    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)
+    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 06c783f6bb9..e3f18e435c7 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -186,16 +186,15 @@ def resize_image_tensor(
         dtype = image.dtype
         acceptable_dtypes = [torch.float32, torch.float64]
-        if interpolation in [InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT]:
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
             # uint8 dtype can be included for cpu and cuda input if nearest mode
             acceptable_dtypes.append(torch.uint8)
         elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
             # uint8 dtype support for bilinear mode is limited to cpu and
             # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
             # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
-            if not torch.jit.is_scripting():
-                if "AVX2" in torch.backends.cpu.get_cpu_capability():
-                    acceptable_dtypes.append(torch.uint8)
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
 
         need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

From 427c4c1f6a1c0c8dabaebd3207148adc648b527e Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 9 May 2023 22:35:53 +0200
Subject: [PATCH 4/7] Added strides fix for 3D CL-like tensors in Resize

Added tests on mem format
---
 test/common_utils.py                      | 23 +++++++++++++----
 test/test_transforms_v2_functional.py     | 25 +++++++++++++++++++
 .../transforms/v2/functional/_geometry.py | 13 +++++++++-
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index c5826a36ff5..bf324d691c0 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format
 
     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+
 
 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -530,11 +534,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]
 
         image_pil = (
@@ -550,19 +556,26 @@ def fn(shape, dtype, device):
             )
         )
 
-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+        assert image_tensor[None].is_contiguous(memory_format=memory_format)
 
         return datapoints.Image(image_tensor)
 
-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
 
 
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index ee9576b6487..a74c7d73d26 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1365,3 +1365,28 @@ def test_correctness_uniform_temporal_subsample(device):
 
     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and related torchvision workaround
+# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index e3f18e435c7..6656d2fed0d 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -192,10 +192,21 @@ def resize_image_tensor(
         elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
             # uint8 dtype support for bilinear mode is limited to cpu and
             # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-            # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
             if "AVX2" in torch.backends.cpu.get_cpu_capability():
                 acceptable_dtypes.append(torch.uint8)
 
+        if image.is_contiguous(memory_format=torch.channels_last):
+            strides = image.stride()
+            numel = image.numel()
+            if image.shape[0] == 1 and numel != strides[0]:
+                # This is the case when channels last tensor was squeezed and unsqueezed such that
+                # stride[0] set as image.shape[1] * image.stride()[1] instead of being image.numel()
+                # Let's restride image such that it will be correctly treated as channels last.
+                # Related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+                new_strides = list(strides)
+                new_strides[0] = numel
+                image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
         need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

From 72ac231fa044ef8c3fb27c4008aa2bbe00934727 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 11 May 2023 22:03:10 +0200
Subject: [PATCH 5/7] Updates, patches and test updates

---
 test/common_utils.py                      | 15 ++++++++++-----
 test/test_transforms_v2_consistency.py    | 13 ++++++++++---
 test/transforms_v2_kernel_infos.py        | 18 +++++++++++-------
 .../transforms/v2/functional/_geometry.py |  6 ++++++
 4 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index bf324d691c0..424717f6c8d 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -497,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)
 
-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 make_image = from_loader(make_image_loader)
@@ -757,8 +760,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
 
-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)
 
     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index a8a87cd43dd..042f42b6f0b 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
 
 class TestToTensorTransforms:
diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py
index 1678c3fb230..e5873f80d15 100644
--- a/test/transforms_v2_kernel_infos.py
+++ b/test/transforms_v2_kernel_infos.py
@@ -1569,7 +1569,7 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
@@ -1577,23 +1577,27 @@ def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )
 
-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
 
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
-                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
             ),
+            lambda shape, dtype, device, memory_format: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 6656d2fed0d..5a7800dd2bb 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -195,6 +195,12 @@ def resize_image_tensor(
             if "AVX2" in torch.backends.cpu.get_cpu_capability():
                 acceptable_dtypes.append(torch.uint8)
 
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
         if image.is_contiguous(memory_format=torch.channels_last):
             strides = image.stride()
             numel = image.numel()

From b878b3d029ed22e2bb4d1ace2c9b5b0db278eb44 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 12 May 2023 17:39:01 +0200
Subject: [PATCH 6/7] Updates according to the review

---
 test/common_utils.py                      |  1 -
 test/test_transforms_v2_consistency.py    | 10 +++++-----
 test/test_transforms_v2_functional.py     |  5 ++++
 .../transforms/v2/functional/_geometry.py | 27 ++++++++++++----
 4 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 424717f6c8d..1d0b82a827c 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -565,7 +565,6 @@ def fn(shape, dtype, device, memory_format):
         else:
             image_tensor = image_tensor.to(device=device)
         image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
-        assert image_tensor[None].is_contiguous(memory_format=memory_format)
 
         return datapoints.Image(image_tensor)
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 042f42b6f0b..652c649aae9 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -98,7 +98,7 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
         closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -315,7 +315,7 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
         closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -787,7 +787,7 @@ def test_compose(self):
             ]
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
         check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -812,7 +812,7 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
         check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -838,7 +838,7 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for biliear and nearest modes
+        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
         check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
 
 
 class TestToTensorTransforms:
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index a74c7d73d26..ed861fee97e 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1382,6 +1382,11 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg
     assert input.ndim == 3, error_msg_fn
     input_stride = input.stride()
     output_stride = output.stride()
+    # Here we check output memory format according to the input:
+    # if input_stride is (..., 1) then input is most likely channels first and thus
+    # output strides should match channels first strides (H * W, H, 1)
+    # if input_stride is (1, ...) then input is most likely channels last and thus
+    # output strides should match channels last strides (1, W * C, C)
     if input_stride[-1] == 1:
         expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
         assert expected_stride == output_stride, error_msg_fn("")
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 5a7800dd2bb..c9551c9eea4 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -176,12 +176,13 @@ def resize_image_tensor(
         antialias = False
 
     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
 
     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
 
         dtype = image.dtype
@@ -201,17 +202,19 @@ def resize_image_tensor(
         ):
             image = image.contiguous(memory_format=torch.channels_last)
 
-        if image.is_contiguous(memory_format=torch.channels_last):
-            strides = image.stride()
-            numel = image.numel()
-            if image.shape[0] == 1 and numel != strides[0]:
-                # This is the case when channels last tensor was squeezed and unsqueezed such that
-                # stride[0] set as image.shape[1] * image.stride()[1] instead of being image.numel()
-                # Let's restride image such that it will be correctly treated as channels last.
-                # Related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
-                new_strides = list(strides)
-                new_strides[0] = numel
-                image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
 
         need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)

From 0e27ad89334d0067ae7fa577285cb2cac19c6063 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 16 May 2023 12:50:30 +0200
Subject: [PATCH 7/7] Set rtol=0

---
 test/test_transforms_v2_consistency.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 652c649aae9..05ab6b67af5 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -98,8 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        closeness_kwargs=dict(rtol=1, atol=1),
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -315,8 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        closeness_kwargs=dict(rtol=1, atol=1),
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -787,8 +787,8 @@ def test_compose(self):
             ]
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -812,8 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -838,8 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        # rtol=atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
 
 class TestToTensorTransforms:
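
For a quick local check of the behaviour described in PATCH 1/7, the sketch below times the uint8 Resize v2 path against an explicit float round trip with torch.utils.benchmark. This is an illustrative sketch, not the script behind the tables above (that lives in the linked gist); it assumes torchvision >= 0.15 with the transforms.v2 namespace and a PyTorch build that exposes torch.backends.cpu.get_cpu_capability(), and the input shape and target size are arbitrary choices.

```
import torch
import torch.utils.benchmark as benchmark
from torchvision.transforms import InterpolationMode, v2

# The native uint8 bilinear path discussed above is only taken on CPUs reporting AVX2.
print("CPU capability:", torch.backends.cpu.get_cpu_capability())

img_u8 = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)
resize = v2.Resize((256, 256), interpolation=InterpolationMode.BILINEAR, antialias=True)

def resize_via_float(img):
    # Baseline imitating the old u8 -> f32 -> interpolate -> u8 round trip.
    out = resize(img.to(torch.float32))
    return out.round_().clamp_(0, 255).to(torch.uint8)

for label, fn in [
    ("native uint8", lambda: resize(img_u8)),
    ("float round trip", lambda: resize_via_float(img_u8)),
]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    measurement = timer.blocked_autorange(min_run_time=2)
    print(f"{label}: {measurement.median * 1e6:.1f} us")
```

On an AVX2-capable CPU the native uint8 timing should come out well below the float round trip for bilinear mode, in line with the numbers quoted in the first commit message.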