
Commit 72ac231

Updates, patches and test updates
1 parent 427c4c1 commit 72ac231

4 files changed: +37 additions, −15 deletions

test/common_utils.py

Lines changed: 10 additions & 5 deletions
@@ -497,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)

-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)


 make_image = from_loader(make_image_loader)
@@ -757,8 +760,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)

     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
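The new `memory_format` argument is ultimately forwarded to `torch.testing.make_tensor`, which the loader above calls. A minimal standalone sketch of what requesting `torch.channels_last` produces, independent of the `ImageLoader` machinery (the shape below is arbitrary):

    import torch

    # 4D (N, C, H, W) shape; channels_last only applies to 4D tensors.
    data = torch.testing.make_tensor(
        (2, 3, 16, 16), low=0, high=255, dtype=torch.uint8, device="cpu",
        memory_format=torch.channels_last,
    )
    print(data.is_contiguous(memory_format=torch.channels_last))  # True
    print(data.is_contiguous())                                   # False: strides follow NHWC order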

test/test_transforms_v2_consistency.py

Lines changed: 10 additions & 3 deletions
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=1, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))

     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))

         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # rtol=atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=1, atol=1))


 class TestToTensorTransforms:
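As a rough illustration of what these relaxed tolerances buy (not part of the commit; plain `torch.testing.assert_close` standing in for the suite's consistency helper): for integer dtypes the default comparison is exact, so the off-by-one rounding differences between the v1 float path and the native uint8 path would fail without `rtol`/`atol` set to 1.

    import torch

    # Hypothetical outputs of the two Resize paths, differing by one intensity level.
    v1_out = torch.tensor([10, 200, 255], dtype=torch.uint8)
    v2_out = torch.tensor([11, 199, 255], dtype=torch.uint8)

    # Default tolerances for integer dtypes are rtol=0, atol=0, i.e. exact equality:
    # torch.testing.assert_close(v1_out, v2_out)  # would raise AssertionError

    # With rtol=1, atol=1 the allowed difference is atol + rtol * |expected|,
    # so the off-by-one results above are accepted.
    torch.testing.assert_close(v1_out, v2_out, rtol=1, atol=1)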

test/transforms_v2_kernel_infos.py

Lines changed: 11 additions & 7 deletions
@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
         else:
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )

-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
-                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
             ),
+            lambda shape, dtype, device, memory_format: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
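A note on the `.to(memory_format=..., copy=True)` pattern used above (generic PyTorch behavior, not specific to this commit): without `copy=True`, `Tensor.to` may return the input tensor itself when nothing needs to change, whereas `copy=True` always materializes a fresh tensor in the requested layout, so the reference inputs never alias their source data. A small sketch:

    import torch

    x = torch.zeros(2, 3, 8, 8)  # already in the default contiguous layout

    same = x.to(memory_format=torch.contiguous_format)              # no change needed: returns ``x`` itself
    fresh = x.to(memory_format=torch.contiguous_format, copy=True)  # always a new tensor

    print(same.data_ptr() == x.data_ptr())   # True: aliases the original storage
    print(fresh.data_ptr() == x.data_ptr())  # False: independent storage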

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 6 additions & 0 deletions
@@ -195,6 +195,12 @@ def resize_image_tensor(
             if "AVX2" in torch.backends.cpu.get_cpu_capability():
                 acceptable_dtypes.append(torch.uint8)

+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
         if image.is_contiguous(memory_format=torch.channels_last):
             strides = image.stride()
             numel = image.numel()
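A minimal standalone illustration of the guard added above (assuming a 4D uint8 image; `resize_image_tensor` itself is not imported here): a strided view is neither contiguous nor channels_last-contiguous, and the workaround normalizes such inputs to channels_last before the native uint8 interpolate path runs.

    import torch

    img = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
    view = img[..., ::2, ::2]  # strided view: matches neither memory format

    print(view.is_contiguous())                                   # False
    print(view.is_contiguous(memory_format=torch.channels_last))  # False

    # Same check as the TODO block above: normalize ambiguous layouts to channels_last.
    if view.dtype == torch.uint8 and not (
        view.is_contiguous() or view.is_contiguous(memory_format=torch.channels_last)
    ):
        view = view.contiguous(memory_format=torch.channels_last)

    print(view.is_contiguous(memory_format=torch.channels_last))  # True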
