diff --git a/test/common_utils.py b/test/common_utils.py
index c5826a36ff5..1d0b82a827c 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format
 
     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+
 
 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -493,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)
 
-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 make_image = from_loader(make_image_loader)
@@ -530,11 +537,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]
 
         image_pil = (
@@ -550,19 +559,25 @@ def fn(shape, dtype, device):
             )
         )
 
-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device)
+        else:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
 
         return datapoints.Image(image_tensor)
 
-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
 
 
@@ -744,8 +759,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
 
-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)
 
     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
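Illustrative note (not part of the patch): with the `memory_format` plumbing above, the interpolation loaders can now hand out channels_last test images. A minimal usage sketch, assuming it is run from torchvision's test/ directory so that `common_utils` is importable:

import torch
from common_utils import make_image_loader_for_interpolation

loader = make_image_loader_for_interpolation(size=(233, 147), memory_format=torch.channels_last)
image = loader.load("cpu")
# Faking a batch dimension lets us query the layout; this should report True for channels_last inputs.
print(image[None].is_contiguous(memory_format=torch.channels_last))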
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index a8a87cd43dd..05ab6b67af5 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
 
 class TestToTensorTransforms:
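Hedged sketch (not part of the patch) of why atol=1 is used above: Resize v2 feeds uint8 inputs to the native uint8 interpolate kernel, while v1 goes through float32 and rounds back, and the two paths can differ by about one intensity level due to rounding. The comparison below assumes `torchvision.transforms.v2.functional.resize` on a plain uint8 tensor takes the native path:

import torch
from torchvision.transforms.v2 import functional as F

img_u8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out_u8 = F.resize(img_u8, [16, 16], antialias=True)                     # native uint8 path
out_f32 = F.resize(img_u8.to(torch.float32), [16, 16], antialias=True)  # float path, as in v1
# Expected to stay at or below 1, which is why the configs above use closeness_kwargs=dict(rtol=0, atol=1).
print((out_u8.float() - out_f32.round()).abs().max())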
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index ee9576b6487..ed861fee97e 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1365,3 +1365,33 @@ def test_correctness_uniform_temporal_subsample(device):
     out_video = F.uniform_temporal_subsample(video, 8)
 
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and the related torchvision workaround
+# once we have fixed the related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    # Here we check the output memory format according to the input:
+    # if input_stride is (..., 1) then the input is most likely channels first and thus
+    # the output strides should match channels first strides (H * W, W, 1)
+    # if input_stride is (1, ...) then the input is most likely channels last and thus
+    # the output strides should match channels last strides (1, W * C, C)
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py
index 1678c3fb230..e5873f80d15 100644
--- a/test/transforms_v2_kernel_infos.py
+++ b/test/transforms_v2_kernel_infos.py
@@ -1569,7 +1569,7 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
@@ -1577,23 +1577,27 @@ def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )
 
-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
 
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
-                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
             ),
+            lambda shape, dtype, device, memory_format: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
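Illustrative note (not part of the patch): the stride expectations asserted in test_memory_format_consistency_resize_image_tensor above follow directly from how a 3D CHW image is laid out. A contiguous (channels first) image has strides (H * W, W, 1), while a channels_last-style image, with channels varying fastest, has strides (1, W * C, C):

import torch

C, H, W = 3, 8, 10
chw = torch.rand(C, H, W)
assert chw.stride() == (H * W, W, 1)  # channels first layout

chw_cl = chw.unsqueeze(0).contiguous(memory_format=torch.channels_last).squeeze(0)
assert chw_cl.stride() == (1, W * C, C)  # channels last layout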
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 1ceabbd80f0..c9551c9eea4 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -176,16 +176,47 @@ def resize_image_tensor(
         antialias = False
 
     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
 
     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
 
         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+            # uint8 dtype can be included for cpu and cuda inputs in nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and,
+            # according to our benchmarks, non-AVX CPUs should prefer the u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+        need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)
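Hedged sketch (not part of the patch, plain torch only): the restriding workaround above can be reproduced outside resize_image_tensor. For a single CHW image faked into a 1CHW batch, the stride of the batch dimension typically does not equal numel(), which leaves the layout ambiguous to `interpolate()`; rewriting that stride to numel() makes the channels_last intent explicit:

import torch

img = torch.rand(3, 32, 32)
img = img.unsqueeze(0).contiguous(memory_format=torch.channels_last).squeeze(0)  # CHW with strides (1, W*C, C)
batch = img.reshape(-1, *img.shape)  # fake 1CHW batch; stride()[0] is usually not numel()

if batch.is_contiguous(memory_format=torch.channels_last) and batch.shape[0] == 1 and batch.numel() != batch.stride()[0]:
    new_strides = list(batch.stride())
    new_strides[0] = batch.numel()
    batch = batch.as_strided(batch.shape, new_strides)

out = torch.nn.functional.interpolate(batch, size=(16, 16), mode="bilinear", align_corners=False)
# With the restrided input, the output is expected to keep channels_last strides, e.g. (768, 1, 48, 3) here.
print(out.stride())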