Skip to content

Commit e3b1d1e

Browse files
NicolasHug, pmeier, vfdev-5
authored and committed
[fbsync] Change default of antialias parameter from None to 'warn' (#7160)
Summary:
Reviewed By: vmoens
Differential Revision: D44416275
fbshipit-source-id: 916691a68545d0b487d9ac20b4a8f42ec42315b6
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: vfdev <[email protected]>
1 parent 833f5b9 commit e3b1d1e

21 files changed

+345
-79
lines changed

gallery/plot_optical_flow.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -81,9 +81,10 @@ def plot(imgs, **imshow_kwargs):
8181

8282
#########################
8383
# The RAFT model accepts RGB images. We first get the frames from
84-
# :func:`~torchvision.io.read_video` and resize them to ensure their
85-
# dimensions are divisible by 8. Then we use the transforms bundled into the
86-
# weights in order to preprocess the input and rescale its values to the
84+
# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions
85+
# are divisible by 8. Note that we explicitly use ``antialias=False``, because
86+
# this is how those models were trained. Then we use the transforms bundled into
87+
# the weights in order to preprocess the input and rescale its values to the
8788
# required ``[-1, 1]`` interval.
8889

8990
from torchvision.models.optical_flow import Raft_Large_Weights
@@ -93,8 +94,8 @@ def plot(imgs, **imshow_kwargs):
9394

9495

9596
def preprocess(img1_batch, img2_batch):
96-
img1_batch = F.resize(img1_batch, size=[520, 960])
97-
img2_batch = F.resize(img2_batch, size=[520, 960])
97+
img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
98+
img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False)
9899
return transforms(img1_batch, img2_batch)
99100

100101

references/depth/stereo/transforms.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,11 @@ def forward(
455455
INTERP_MODE = self._interpolation_mode_strategy()
456456

457457
for img in images:
458-
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE),)
458+
# We hard-code antialias=False to preserve results after we changed
459+
# its default from None to True (see
460+
# https://github.com/pytorch/vision/pull/7160)
461+
# TODO: we could re-train the stereo models with antialias=True?
462+
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)
459463

460464
for dsp in disparities:
461465
if dsp is not None:

references/optical_flow/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,8 +196,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):
196196

197197
if torch.rand(1).item() < self.resize_prob:
198198
# rescale the images
199-
img1 = F.resize(img1, size=(new_h, new_w))
200-
img2 = F.resize(img2, size=(new_h, new_w))
199+
# We hard-code antialias=False to preserve results after we changed
200+
# its default from None to True (see
201+
# https://github.com/pytorch/vision/pull/7160)
202+
# TODO: we could re-train the OF models with antialias=True?
203+
img1 = F.resize(img1, size=(new_h, new_w), antialias=False)
204+
img2 = F.resize(img2, size=(new_h, new_w), antialias=False)
201205
if valid_flow_mask is None:
202206
flow = F.resize(flow, size=(new_h, new_w))
203207
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]

references/video_classification/presets.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,11 @@ def __init__(
1515
):
1616
trans = [
1717
transforms.ConvertImageDtype(torch.float32),
18-
transforms.Resize(resize_size),
18+
# We hard-code antialias=False to preserve results after we changed
19+
# its default from None to True (see
20+
# https://github.com/pytorch/vision/pull/7160)
21+
# TODO: we could re-train the video models with antialias=True?
22+
transforms.Resize(resize_size, antialias=False),
1923
]
2024
if hflip_prob > 0:
2125
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
@@ -31,7 +35,11 @@ def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645),
3135
self.transforms = transforms.Compose(
3236
[
3337
transforms.ConvertImageDtype(torch.float32),
34-
transforms.Resize(resize_size),
38+
# We hard-code antialias=False to preserve results after we changed
39+
# its default from None to True (see
40+
# https://github.com/pytorch/vision/pull/7160)
41+
# TODO: we could re-train the video models with antialias=True?
42+
transforms.Resize(resize_size, antialias=False),
3543
transforms.Normalize(mean=mean, std=std),
3644
transforms.CenterCrop(crop_size),
3745
ConvertBCHWtoCBHW(),

test/test_functional_tensor.py

Lines changed: 36 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import math
44
import os
5+
import warnings
56
from functools import partial
67
from typing import Sequence
78

@@ -483,8 +484,8 @@ def test_resize(device, dt, size, max_size, interpolation):
483484
tensor = tensor.to(dt)
484485
batch_tensors = batch_tensors.to(dt)
485486

486-
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
487-
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
487+
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
488+
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
488489

489490
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
490491

@@ -509,10 +510,12 @@ def test_resize(device, dt, size, max_size, interpolation):
509510
else:
510511
script_size = size
511512

512-
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
513+
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
513514
assert_equal(resized_tensor, resize_result)
514515

515-
_test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size)
516+
_test_fn_on_batch(
517+
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
518+
)
516519

517520

518521
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -547,7 +550,7 @@ def test_resize_antialias(device, dt, size, interpolation):
547550
tensor = tensor.to(dt)
548551

549552
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
550-
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
553+
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)
551554

552555
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
553556

@@ -596,6 +599,23 @@ def test_assert_resize_antialias(interpolation):
596599
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
597600

598601

602+
def test_resize_antialias_default_warning():
603+
604+
img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
605+
606+
match = "The default value of the antialias"
607+
with pytest.warns(UserWarning, match=match):
608+
F.resize(img, size=(20, 20))
609+
with pytest.warns(UserWarning, match=match):
610+
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20))
611+
612+
# For modes that aren't bicubic or bilinear, don't throw a warning
613+
with warnings.catch_warnings():
614+
warnings.simplefilter("error")
615+
F.resize(img, size=(20, 20), interpolation=NEAREST)
616+
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)
617+
618+
599619
@pytest.mark.parametrize("device", cpu_and_gpu())
600620
@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16])
601621
@pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]])
@@ -924,7 +944,9 @@ def test_resized_crop(device, mode):
924944
# 1) resize to the same size, crop to the same size => should be identity
925945
tensor, _ = _create_data(26, 36, device=device)
926946

927-
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
947+
out_tensor = F.resized_crop(
948+
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
949+
)
928950
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
929951

930952
# 2) resize by half and crop a TL corner
@@ -939,7 +961,14 @@ def test_resized_crop(device, mode):
939961

940962
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
941963
_test_fn_on_batch(
942-
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
964+
batch_tensors,
965+
F.resized_crop,
966+
top=1,
967+
left=2,
968+
height=20,
969+
width=30,
970+
size=[10, 15],
971+
interpolation=NEAREST,
943972
)
944973

945974

test/test_models.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1050,5 +1050,25 @@ def test_raft(model_fn, scripted):
10501050
_assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)
10511051

10521052

1053+
def test_presets_antialias():
1054+
1055+
img = torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8)
1056+
1057+
match = "The default value of the antialias parameter"
1058+
with pytest.warns(UserWarning, match=match):
1059+
models.ResNet18_Weights.DEFAULT.transforms()(img)
1060+
with pytest.warns(UserWarning, match=match):
1061+
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms()(img)
1062+
1063+
with warnings.catch_warnings():
1064+
warnings.simplefilter("error")
1065+
models.ResNet18_Weights.DEFAULT.transforms(antialias=True)(img)
1066+
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms(antialias=True)(img)
1067+
1068+
models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()(img)
1069+
models.video.R3D_18_Weights.DEFAULT.transforms()(img)
1070+
models.optical_flow.Raft_Small_Weights.DEFAULT.transforms()(img, img)
1071+
1072+
10531073
if __name__ == "__main__":
10541074
pytest.main([__file__])

test/test_prototype_transforms.py

Lines changed: 67 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import re
3+
import warnings
34
from collections import defaultdict
45

56
import numpy as np
@@ -94,7 +95,7 @@ def parametrize_from_transforms(*transforms):
9495
class TestSmoke:
9596
@parametrize_from_transforms(
9697
transforms.RandomErasing(p=1.0),
97-
transforms.Resize([16, 16]),
98+
transforms.Resize([16, 16], antialias=True),
9899
transforms.CenterCrop([16, 16]),
99100
transforms.ConvertDtype(),
100101
transforms.RandomHorizontalFlip(),
@@ -210,7 +211,7 @@ def test_normalize(self, transform, input):
210211
@parametrize(
211212
[
212213
(
213-
transforms.RandomResizedCrop([16, 16]),
214+
transforms.RandomResizedCrop([16, 16], antialias=True),
214215
itertools.chain(
215216
make_images(extra_dims=[(4,)]),
216217
make_vanilla_tensor_images(),
@@ -1991,6 +1992,70 @@ def test__transform(self, inpt):
19911992
assert output.dtype == inpt.dtype
19921993

19931994

1995+
# TODO: remove this test in 0.17 when the default of antialias changes to True
1996+
def test_antialias_warning():
1997+
pil_img = PIL.Image.new("RGB", size=(10, 10), color=127)
1998+
tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8)
1999+
tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8)
2000+
2001+
match = "The default value of the antialias parameter"
2002+
with pytest.warns(UserWarning, match=match):
2003+
transforms.Resize((20, 20))(tensor_img)
2004+
with pytest.warns(UserWarning, match=match):
2005+
transforms.RandomResizedCrop((20, 20))(tensor_img)
2006+
with pytest.warns(UserWarning, match=match):
2007+
transforms.ScaleJitter((20, 20))(tensor_img)
2008+
with pytest.warns(UserWarning, match=match):
2009+
transforms.RandomShortestSize((20, 20))(tensor_img)
2010+
with pytest.warns(UserWarning, match=match):
2011+
transforms.RandomResize(10, 20)(tensor_img)
2012+
2013+
with pytest.warns(UserWarning, match=match):
2014+
transforms.functional.resize(tensor_img, (20, 20))
2015+
with pytest.warns(UserWarning, match=match):
2016+
transforms.functional.resize_image_tensor(tensor_img, (20, 20))
2017+
2018+
with pytest.warns(UserWarning, match=match):
2019+
transforms.functional.resize(tensor_video, (20, 20))
2020+
with pytest.warns(UserWarning, match=match):
2021+
transforms.functional.resize_video(tensor_video, (20, 20))
2022+
2023+
with pytest.warns(UserWarning, match=match):
2024+
datapoints.Image(tensor_img).resize((20, 20))
2025+
with pytest.warns(UserWarning, match=match):
2026+
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
2027+
2028+
with pytest.warns(UserWarning, match=match):
2029+
datapoints.Video(tensor_video).resize((20, 20))
2030+
with pytest.warns(UserWarning, match=match):
2031+
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
2032+
2033+
with warnings.catch_warnings():
2034+
warnings.simplefilter("error")
2035+
transforms.Resize((20, 20))(pil_img)
2036+
transforms.RandomResizedCrop((20, 20))(pil_img)
2037+
transforms.ScaleJitter((20, 20))(pil_img)
2038+
transforms.RandomShortestSize((20, 20))(pil_img)
2039+
transforms.RandomResize(10, 20)(pil_img)
2040+
transforms.functional.resize(pil_img, (20, 20))
2041+
2042+
transforms.Resize((20, 20), antialias=True)(tensor_img)
2043+
transforms.RandomResizedCrop((20, 20), antialias=True)(tensor_img)
2044+
transforms.ScaleJitter((20, 20), antialias=True)(tensor_img)
2045+
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
2046+
transforms.RandomResize(10, 20, antialias=True)(tensor_img)
2047+
2048+
transforms.functional.resize(tensor_img, (20, 20), antialias=True)
2049+
transforms.functional.resize_image_tensor(tensor_img, (20, 20), antialias=True)
2050+
transforms.functional.resize(tensor_video, (20, 20), antialias=True)
2051+
transforms.functional.resize_video(tensor_video, (20, 20), antialias=True)
2052+
2053+
datapoints.Image(tensor_img).resize((20, 20), antialias=True)
2054+
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
2055+
datapoints.Video(tensor_video).resize((20, 20), antialias=True)
2056+
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
2057+
2058+
19942059
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
19952060
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
19962061
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))

test/test_transforms.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import random
44
import re
5+
import warnings
56
from functools import partial
67

78
import numpy as np
@@ -319,7 +320,7 @@ def test_randomresized_params():
319320
scale_range = (scale_min, scale_min + round(random.random(), 2))
320321
aspect_min = max(round(random.random(), 2), epsilon)
321322
aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
322-
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
323+
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True)
323324
i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
324325
aspect_ratio_obtained = w / h
325326
assert (
@@ -366,7 +367,7 @@ def test_randomresized_params():
366367
def test_resize(height, width, osize, max_size):
367368
img = Image.new("RGB", size=(width, height), color=127)
368369

369-
t = transforms.Resize(osize, max_size=max_size)
370+
t = transforms.Resize(osize, max_size=max_size, antialias=True)
370371
result = t(img)
371372

372373
msg = f"{height}, {width} - {osize} - {max_size}"
@@ -424,7 +425,7 @@ def test_resize_sequence_output(height, width, osize):
424425
img = Image.new("RGB", size=(width, height), color=127)
425426
oheight, owidth = osize
426427

427-
t = transforms.Resize(osize)
428+
t = transforms.Resize(osize, antialias=True)
428429
result = t(img)
429430

430431
assert (owidth, oheight) == result.size
@@ -439,6 +440,16 @@ def test_resize_antialias_error():
439440
t(img)
440441

441442

443+
def test_resize_antialias_default_warning():
444+
445+
img = Image.new("RGB", size=(10, 10), color=127)
446+
# We make sure we don't warn for PIL images since the default behaviour doesn't change
447+
with warnings.catch_warnings():
448+
warnings.simplefilter("error")
449+
transforms.Resize((20, 20))(img)
450+
transforms.RandomResizedCrop((20, 20))(img)
451+
452+
442453
@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
443454
def test_resize_size_equals_small_edge_size(height, width):
444455
# Non-regression test for https://github.com/pytorch/vision/issues/5405
@@ -447,7 +458,7 @@ def test_resize_size_equals_small_edge_size(height, width):
447458
img = Image.new("RGB", size=(width, height), color=127)
448459

449460
small_edge = min(height, width)
450-
t = transforms.Resize(small_edge, max_size=max_size)
461+
t = transforms.Resize(small_edge, max_size=max_size, antialias=True)
451462
result = t(img)
452463
assert max(result.size) == max_size
453464

@@ -1424,11 +1435,11 @@ def test_random_choice(proba_passthrough, seed):
14241435
def test_random_order():
14251436
random_state = random.getstate()
14261437
random.seed(42)
1427-
random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)])
1438+
random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)])
14281439
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
14291440
num_samples = 250
14301441
num_normal_order = 0
1431-
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
1442+
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img))
14321443
for _ in range(num_samples):
14331444
out = random_order_transform(img)
14341445
if out == resize_crop_out:

0 commit comments

Comments
 (0)