Skip to content

Commit 528651a

Browse files
authored
[proto] Fix bug with Compose and PR 6504 (#6510)
* [proto] Fix bug with Compose and PR 6504 * Added tests and fixed other bugs
1 parent 7245dc9 commit 528651a

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

test/test_prototype_transforms.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1108,13 +1108,18 @@ def test_assertions(self, transform_cls):
11081108

11091109
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
11101110
@pytest.mark.parametrize(
1111-
"trfms", [[transforms.Pad(2), transforms.RandomCrop(28)], [lambda x: 2.0 * x, transforms.RandomCrop(28)]]
1111+
"trfms",
1112+
[
1113+
[transforms.Pad(2), transforms.RandomCrop(28)],
1114+
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
1115+
],
11121116
)
11131117
def test_ctor(self, transform_cls, trfms):
11141118
c = transform_cls(trfms)
11151119
inpt = torch.rand(1, 3, 32, 32)
11161120
output = c(inpt)
11171121
assert isinstance(output, torch.Tensor)
1122+
assert output.ndim == 4
11181123

11191124

11201125
class TestRandomChoice:

torchvision/prototype/transforms/_container.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
1515
self.transforms = transforms
1616

1717
def forward(self, *inputs: Any) -> Any:
18+
sample = inputs if len(inputs) > 1 else inputs[0]
1819
for transform in self.transforms:
19-
inputs = transform(*inputs)
20-
return inputs
20+
sample = transform(sample)
21+
return sample
2122

2223

2324
class RandomApply(_RandomApplyTransform):
@@ -76,7 +77,8 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
7677
self.transforms = transforms
7778

7879
def forward(self, *inputs: Any) -> Any:
80+
sample = inputs if len(inputs) > 1 else inputs[0]
7981
for idx in torch.randperm(len(self.transforms)):
8082
transform = self.transforms[idx]
81-
inputs = transform(*inputs)
82-
return inputs
83+
sample = transform(sample)
84+
return sample

0 commit comments

Comments
 (0)