|
14 | 14 | import torchvision.transforms as T
|
15 | 15 | from torchvision.transforms import InterpolationMode
|
16 | 16 |
|
17 |
| -from common_utils import TransformsTester, cpu_and_gpu |
| 17 | +from common_utils import TransformsTester, cpu_and_gpu, needs_cuda |
18 | 18 |
|
19 | 19 | from typing import Dict, List, Sequence, Tuple
|
20 | 20 |
|
@@ -868,12 +868,14 @@ def test_perspective_interpolation_warning(tester):
|
868 | 868 | tester.assertTrue(res1.equal(res2))
|
869 | 869 |
|
870 | 870 |
|
871 |
| -@pytest.mark.parametrize('device', ["cpu", ]) |
| 871 | +@pytest.mark.parametrize('device', cpu_and_gpu()) |
872 | 872 | @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
|
873 | 873 | @pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
|
874 | 874 | @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
|
875 | 875 | def test_resize_antialias(device, dt, size, interpolation, tester):
|
876 | 876 |
|
| 877 | + torch.manual_seed(12) |
| 878 | + |
877 | 879 | if dt == torch.float16 and device == "cpu":
|
878 | 880 | # skip float16 on CPU case
|
879 | 881 | return
|
@@ -924,6 +926,19 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
|
924 | 926 | tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
|
925 | 927 |
|
926 | 928 |
|
@needs_cuda
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_assert_resize_antialias(interpolation, tester):

    # Exercise antialiased resize with an extreme downscale factor (1000 -> 5)
    # on CUDA; the TORCH_CHECK guard inside interpolate_aa_kernels.cu is
    # expected to reject such scales with a RuntimeError.
    torch.manual_seed(12)
    tensor, _ = tester._create_data(1000, 1000, device="cuda")

    expected_error = r"Max supported scale factor is"
    with pytest.raises(RuntimeError, match=expected_error):
        F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
| 941 | + |
927 | 942 | def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
|
928 | 943 |
|
929 | 944 | tester = Tester()
|
|
0 commit comments