diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index bb20f8a7b3a..c6709a5e550 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -806,6 +806,11 @@ def test_random_apply(self, p, sequence_type): check_call_consistency(prototype_transform, legacy_transform) + if sequence_type is nn.ModuleList: + # quick and dirty test that it is jit-scriptable + scripted = torch.jit.script(prototype_transform) + scripted(torch.rand(1, 3, 300, 300)) + # We can't test other values for `p` since the random parameter generation is different @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)]) def test_random_choice(self, probabilities): diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index 938f59f64ae..42c73a2c11e 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -1,9 +1,10 @@ import warnings -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch from torch import nn +from torchvision import transforms as _transforms from torchvision.prototype.transforms import Transform @@ -28,6 +29,8 @@ def extra_repr(self) -> str: class RandomApply(Transform): + _v1_transform_cls = _transforms.RandomApply + def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None: super().__init__() @@ -39,6 +42,9 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") self.p = p + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {"transforms": self.transforms, "p": self.p} + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 16c30565d36..7f3c03d5e67 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -141,8 +141,9 @@ def __prepare_scriptable__(self) -> nn.Module: if self._v1_transform_cls is None: raise RuntimeError( f"Transform {type(self).__name__} cannot be JIT scripted. " - f"This is only support for backward compatibility with transforms which already in v1." - f"For torchscript support (on tensors only), you can use the functional API instead." + "torchscript is only supported for backward compatibility with transforms " + "which are already in torchvision.transforms. " + "For torchscript support (on tensors only), you can use the functional API instead." ) return self._v1_transform_cls(**self._extract_params_for_v1_transform())