Port _test_adjust_fn to pytest #3845

Merged
merged 2 commits on May 17, 2021
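
The port replaces the loop-driven unittest helper _test_adjust_fn with a module-level checker, check_functional_vs_PIL_vs_scripted, driven by @pytest.mark.parametrize, so each (device, dtype, config) combination runs and is reported as its own test case. A minimal sketch of the pattern, with illustrative names only:

    import pytest
    import torch

    @pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
    @pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.5, 1.0)])
    def test_pattern_example(dtype, config):
        # pytest generates one test per parameter combination, instead of a
        # single unittest method iterating over dtypes and configs internally.
        ...
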
364 changes: 214 additions & 150 deletions test/test_functional_tensor.py
@@ -324,85 +324,6 @@ def test_pad(self):

        self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
                        dts=(None, torch.float32, torch.float64)):
        script_fn = torch.jit.script(fn)
        torch.manual_seed(15)
        tensor, pil_img = self._create_data(26, 34, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in dts:

            if dt is not None:
                tensor = F.convert_image_dtype(tensor, dt)
                batch_tensors = F.convert_image_dtype(batch_tensors, dt)

            for config in configs:
                adjusted_tensor = fn_t(tensor, **config)
                adjusted_pil = fn_pil(pil_img, **config)
                scripted_result = script_fn(tensor, **config)
                msg = "{}, {}".format(dt, config)
                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)

                rgb_tensor = adjusted_tensor

                if adjusted_tensor.dtype != torch.uint8:
                    rgb_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

                # Check that the max difference does not exceed 2 in the [0, 255] range.
                # Exact matching is not possible due to the incompatibility between convert_image_dtype and PIL results.
                self.approxEqualTensorToPIL(rgb_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)

                atol = 1e-6
                if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
                    atol = 1.0
                self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)

                self._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)

    def test_adjust_brightness(self):
        self._test_adjust_fn(
            F.adjust_brightness,
            F_pil.adjust_brightness,
            F_t.adjust_brightness,
            [{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
        )

    def test_adjust_contrast(self):
        self._test_adjust_fn(
            F.adjust_contrast,
            F_pil.adjust_contrast,
            F_t.adjust_contrast,
            [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_adjust_saturation(self):
        self._test_adjust_fn(
            F.adjust_saturation,
            F_pil.adjust_saturation,
            F_t.adjust_saturation,
            [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
        )

    def test_adjust_hue(self):
        self._test_adjust_fn(
            F.adjust_hue,
            F_pil.adjust_hue,
            F_t.adjust_hue,
            [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
            tol=16.1,
            agg_method="max"
        )

    def test_adjust_gamma(self):
        self._test_adjust_fn(
            F.adjust_gamma,
            F_pil.adjust_gamma,
            F_t.adjust_gamma,
            [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
        )

    def test_resize(self):
        script_fn = torch.jit.script(F.resize)
        tensor, pil_img = self._create_data(26, 36, device=self.device)
@@ -833,77 +754,6 @@ def test_gaussian_blur(self):
                msg="{}, {}".format(ksize, sigma)
            )

    def test_invert(self):
        self._test_adjust_fn(
            F.invert,
            F_pil.invert,
            F_t.invert,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_posterize(self):
        self._test_adjust_fn(
            F.posterize,
            F_pil.posterize,
            F_t.posterize,
            [{"bits": bits} for bits in range(0, 8)],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )

    def test_solarize(self):
        self._test_adjust_fn(
            F.solarize,
            F_pil.solarize,
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )
        self._test_adjust_fn(
            F.solarize,
            lambda img, threshold: F_pil.solarize(img, 255 * threshold),
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
            tol=1.0,
            agg_method="max",
            dts=(torch.float32, torch.float64)
        )

    def test_adjust_sharpness(self):
        self._test_adjust_fn(
            F.adjust_sharpness,
            F_pil.adjust_sharpness,
            F_t.adjust_sharpness,
            [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_autocontrast(self):
        self._test_adjust_fn(
            F.autocontrast,
            F_pil.autocontrast,
            F_t.autocontrast,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_equalize(self):
        torch.set_deterministic(False)
        self._test_adjust_fn(
            F.equalize,
            F_pil.equalize,
            F_t.equalize,
            [{}],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
@@ -1074,5 +924,219 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
    tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")


def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):

    tester = Tester()

    script_fn = torch.jit.script(fn)
    torch.manual_seed(15)
    tensor, pil_img = tester._create_data(26, 34, device=device)
    batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)

    if dtype is not None:
        tensor = F.convert_image_dtype(tensor, dtype)
        batch_tensors = F.convert_image_dtype(batch_tensors, dtype)

    out_fn_t = fn_t(tensor, **config)
    out_pil = fn_pil(pil_img, **config)
    out_scripted = script_fn(tensor, **config)
    assert out_fn_t.dtype == out_scripted.dtype
    assert out_fn_t.size()[1:] == out_pil.size[::-1]

    rgb_tensor = out_fn_t

    if out_fn_t.dtype != torch.uint8:
        rgb_tensor = F.convert_image_dtype(out_fn_t, torch.uint8)

    # Check that the max difference does not exceed 2 in the [0, 255] range.
    # Exact matching is not possible due to the incompatibility between convert_image_dtype and PIL results.
    tester.approxEqualTensorToPIL(rgb_tensor.float(), out_pil, tol=tol, agg_method=agg_method)

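    # For uint8 images on CUDA, allow an off-by-one between eager and scripted results.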
    atol = 1e-6
    if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
        atol = 1.0
    assert out_fn_t.allclose(out_scripted, atol=atol)

    # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
    tester._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
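
A hedged sketch of one way to act on the FIXME, assuming each functional only needs to be scripted once per session: memoize the scripting step (the helper name _scripted is hypothetical, not part of this PR):

    import functools

    @functools.lru_cache(maxsize=None)
    def _scripted(fn):
        # Compile each functional at most once; later calls reuse the cached ScriptFunction.
        return torch.jit.script(fn)

check_functional_vs_PIL_vs_scripted and _test_fn_on_batch could then share _scripted(fn) instead of scripting anew.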


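cpu_and_gpu() comes from the shared test utilities; a plausible sketch of what it returns, stated as an assumption rather than the actual implementation:

    def cpu_and_gpu():
        # Hypothetical sketch: always test on CPU, and on CUDA when available.
        return ('cpu', 'cuda') if torch.cuda.is_available() else ('cpu',)
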
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
def test_adjust_brightness(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_brightness,
        F_pil.adjust_brightness,
        F_t.adjust_brightness,
        config,
        device,
        dtype,
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_invert(device, dtype):
    check_functional_vs_PIL_vs_scripted(
        F.invert,
        F_pil.invert,
        F_t.invert,
        {},
        device,
        dtype,
        tol=1.0,
        agg_method="max"
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)])
def test_posterize(device, config):
    check_functional_vs_PIL_vs_scripted(
        F.posterize,
        F_pil.posterize,
        F_t.posterize,
        config,
        device,
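        # posterize is defined on uint8 images, so only the default (uint8) dtype is exercised.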
        dtype=None,
        tol=1.0,
        agg_method="max",
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
def test_solarize1(device, config):
    check_functional_vs_PIL_vs_scripted(
        F.solarize,
        F_pil.solarize,
        F_t.solarize,
        config,
        device,
        dtype=None,
        tol=1.0,
        agg_method="max",
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
def test_solarize2(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.solarize,
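        # Float tensors lie in [0, 1] while PIL images use [0, 255], so scale the threshold for the PIL reference.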
        lambda img, threshold: F_pil.solarize(img, 255 * threshold),
        F_t.solarize,
        config,
        device,
        dtype,
        tol=1.0,
        agg_method="max",
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_sharpness(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_sharpness,
        F_pil.adjust_sharpness,
        F_t.adjust_sharpness,
        config,
        device,
        dtype,
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_autocontrast(device, dtype):
    check_functional_vs_PIL_vs_scripted(
        F.autocontrast,
        F_pil.autocontrast,
        F_t.autocontrast,
        {},
        device,
        dtype,
        tol=1.0,
        agg_method="max"
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_equalize(device):
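    # Presumably needed because equalize relies on ops without deterministic implementations.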
    torch.set_deterministic(False)
    check_functional_vs_PIL_vs_scripted(
        F.equalize,
        F_pil.equalize,
        F_t.equalize,
        {},
        device,
        dtype=None,
        tol=1.0,
        agg_method="max",
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_contrast(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_contrast,
        F_pil.adjust_contrast,
        F_t.adjust_contrast,
        config,
        device,
        dtype,
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
def test_adjust_saturation(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_saturation,
        F_pil.adjust_saturation,
        F_t.adjust_saturation,
        config,
        device,
        dtype,
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
def test_adjust_hue(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_hue,
        F_pil.adjust_hue,
        F_t.adjust_hue,
        config,
        device,
        dtype,
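        # adjust_hue round-trips through HSV, where the tensor and PIL backends diverge more, hence the larger tolerance.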
        tol=16.1,
        agg_method="max"
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
def test_adjust_gamma(device, dtype, config):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_gamma,
        F_pil.adjust_gamma,
        F_t.adjust_gamma,
        config,
        device,
        dtype,
    )


if __name__ == '__main__':
    unittest.main()