From a1d7c2826f19c1cc469a63ca293bc0d1859c16ee Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 19:53:51 +0100 Subject: [PATCH 1/3] make transforms v2 get_params a staticmethod --- test/test_prototype_transforms_consistency.py | 9 +++++++++ torchvision/prototype/transforms/_transform.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 79a2b591a59..f31031e6d77 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -663,6 +663,15 @@ def test_call_consistency(config, args_kwargs): def test_get_params_alias(config): assert config.prototype_cls.get_params is config.legacy_cls.get_params + if not config.args_kwargs: + return + + args, kwargs = config.args_kwargs[0] + legacy_transform = config.legacy_cls(*args, **kwargs) + prototype_transform = config.prototype_cls(*args, **kwargs) + + assert prototype_transform.get_params is legacy_transform.get_params + @pytest.mark.parametrize( ("transform_cls", "args_kwargs"), diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 18678a5265a..206889ace72 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -67,7 +67,7 @@ def __init_subclass__(cls) -> None: # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): - cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined] + cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined] def _extract_params_for_v1_transform(self) -> Dict[str, Any]: # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current From 20878fa7302a6e9cc2ee636d02833cda9f4c6a97 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 7 Feb 2023 10:36:26 +0100 Subject: [PATCH 2/3] also check get_params scriptability on transform instances --- test/test_prototype_transforms_consistency.py | 66 +++++++++++-------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index f31031e6d77..231a0f5bebd 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -655,17 +655,39 @@ def test_call_consistency(config, args_kwargs): ) -@pytest.mark.parametrize( - "config", - [config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")], - ids=lambda config: config.legacy_cls.__name__, +get_paramsl_parametrization = pytest.mark.parametrize( + ("config", "get_params_args_kwargs"), + [ + pytest.param( + next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls), + get_params_args_kwargs, + id=transform_cls.__name__, + ) + for transform_cls, get_params_args_kwargs in [ + (prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), + (prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), + (prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), + (prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), + (prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), + ( + prototype_transforms.RandomAffine, + ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), + ), + (prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), + (prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), + (prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), + (prototype_transforms.AutoAugment, ArgsKwargs(5)), + ] + ], ) -def test_get_params_alias(config): + + +@get_paramsl_parametrization +def test_get_params_alias(config, get_params_args_kwargs): assert config.prototype_cls.get_params is config.legacy_cls.get_params if not config.args_kwargs: return - args, kwargs = config.args_kwargs[0] legacy_transform = config.legacy_cls(*args, **kwargs) prototype_transform = config.prototype_cls(*args, **kwargs) @@ -673,28 +695,18 @@ def test_get_params_alias(config): assert prototype_transform.get_params is legacy_transform.get_params -@pytest.mark.parametrize( - ("transform_cls", "args_kwargs"), - [ - (prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), - (prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), - (prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), - (prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), - (prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), - ( - prototype_transforms.RandomAffine, - ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), - ), - (prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), - (prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), - (prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), - (prototype_transforms.AutoAugment, ArgsKwargs(5)), - ], -) -def test_get_params_jit(transform_cls, args_kwargs): - args, kwargs = args_kwargs +@get_paramsl_parametrization +def test_get_params_jit(config, get_params_args_kwargs): + get_params_args, get_params_kwargs = get_params_args_kwargs + + torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs) + + if not config.args_kwargs: + return + args, kwargs = config.args_kwargs[0] + transform = config.prototype_cls(*args, **kwargs) - torch.jit.script(transform_cls.get_params)(*args, **kwargs) + torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs) @pytest.mark.parametrize( From 20117d3f83054414d2d556f4fca1e62034fc7b23 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 7 Feb 2023 10:58:58 +0100 Subject: [PATCH 3/3] Update test/test_prototype_transforms_consistency.py Co-authored-by: Nicolas Hug --- test/test_prototype_transforms_consistency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 231a0f5bebd..40186b0159e 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -655,7 +655,7 @@ def test_call_consistency(config, args_kwargs): ) -get_paramsl_parametrization = pytest.mark.parametrize( +get_params_parametrization = pytest.mark.parametrize( ("config", "get_params_args_kwargs"), [ pytest.param(