Skip to content

Commit a06cc94

Browse files
committed
Add test cases using scales* arguments
1 parent 6c9fb39 commit a06cc94

File tree

1 file changed

+70
-36
lines changed

1 file changed

+70
-36
lines changed

tests/py/dynamo/conversion/test_upsample.py

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,113 +8,147 @@
88
class TestUpsampleConverter(DispatchTestCase):
99
@parameterized.expand(
1010
[
11-
((2,), (4,)),
12-
((2,), (5,)),
11+
((2,), (4,), None),
12+
((2,), (5,), None),
13+
((2,), (4,), 2.0),
14+
((2,), (5,), 2.5),
1315
]
1416
)
15-
def test_nearest1d(self, input_shape, output_size):
17+
def test_nearest1d(self, input_shape, output_size, scales):
1618
class Upsample(torch.nn.Module):
1719
def forward(self, input):
18-
return torch.ops.aten.upsample_nearest1d.default(input, output_size)
20+
return torch.ops.aten.upsample_nearest1d.default(
21+
input, output_size, scales
22+
)
1923

2024
input = [torch.randn([1, 1] + list(input_shape))]
2125
self.run_test(Upsample(), input)
2226

2327
@parameterized.expand(
2428
[
25-
((2, 2), (4, 4)),
26-
((2, 2), (5, 5)),
29+
((2, 2), (4, 4), None, None),
30+
((2, 2), (5, 5), None, None),
31+
((2, 2), (4, 4), 2.0, 2.0),
32+
((2, 2), (5, 5), 2.5, 2.5),
2733
]
2834
)
29-
def test_nearest2d(self, input_shape, output_size):
35+
def test_nearest2d(self, input_shape, output_size, scales_h, scales_w):
3036
class Upsample(torch.nn.Module):
3137
def forward(self, input):
32-
return torch.ops.aten.upsample_nearest2d.default(input, output_size)
38+
return torch.ops.aten.upsample_nearest2d.default(
39+
input, output_size, scales_h, scales_w
40+
)
3341

3442
input = [torch.randn([1, 1] + list(input_shape))]
3543
self.run_test(Upsample(), input)
3644

3745
@parameterized.expand(
3846
[
39-
((2, 2, 2), (4, 4, 4)),
40-
((2, 2, 2), (5, 5, 5)),
47+
((2, 2, 2), (4, 4, 4), None, None, None),
48+
((2, 2, 2), (5, 5, 5), None, None, None),
49+
((2, 2, 2), (4, 4, 4), 2.0, 2.0, 2.0),
50+
((2, 2, 2), (5, 5, 5), 2.5, 2.5, 2.5),
4151
]
4252
)
43-
def test_nearest3d(self, input_shape, output_size):
53+
def test_nearest3d(self, input_shape, output_size, scales_d, scales_h, scales_w):
4454
class Upsample(torch.nn.Module):
4555
def forward(self, input):
46-
return torch.ops.aten.upsample_nearest3d.default(input, output_size)
56+
return torch.ops.aten.upsample_nearest3d.default(
57+
input, output_size, scales_d, scales_h, scales_w
58+
)
4759

4860
input = [torch.randn([1, 1] + list(input_shape))]
4961
self.run_test(Upsample(), input)
5062

5163
@parameterized.expand(
5264
[
53-
((2,), (4,), True),
54-
((2,), (4,), False),
55-
((2,), (5,), True),
56-
((2,), (5,), False),
65+
((2,), (4,), True, None),
66+
((2,), (4,), False, None),
67+
((2,), (5,), True, None),
68+
((2,), (5,), False, None),
69+
((2,), (4,), True, 2.0),
70+
((2,), (4,), False, 2.0),
71+
((2,), (5,), True, 2.5),
72+
((2,), (5,), False, 2.5),
5773
]
5874
)
59-
def test_linear1d(self, input_shape, output_size, align_corners):
75+
def test_linear1d(self, input_shape, output_size, align_corners, scales):
6076
class Upsample(torch.nn.Module):
6177
def forward(self, input):
6278
return torch.ops.aten.upsample_linear1d.default(
63-
input, output_size, align_corners
79+
input, output_size, align_corners, scales
6480
)
6581

6682
input = [torch.randn([1, 1] + list(input_shape))]
6783
self.run_test(Upsample(), input)
6884

6985
@parameterized.expand(
7086
[
71-
((2, 2), (4, 4), True),
72-
((2, 2), (4, 4), False),
73-
((2, 2), (5, 5), True),
74-
((2, 2), (5, 5), False),
87+
((2, 2), (4, 4), True, None, None),
88+
((2, 2), (4, 4), False, None, None),
89+
((2, 2), (5, 5), True, None, None),
90+
((2, 2), (5, 5), False, None, None),
91+
((2, 2), (4, 4), True, 2.0, 2.0),
92+
((2, 2), (4, 4), False, 2.0, 2.0),
93+
((2, 2), (5, 5), True, 2.5, 2.5),
94+
((2, 2), (5, 5), False, 2.5, 2.5),
7595
]
7696
)
77-
def test_bilinear2d(self, input_shape, output_size, align_corners):
97+
def test_bilinear2d(
98+
self, input_shape, output_size, align_corners, scales_h, scales_w
99+
):
78100
class Upsample(torch.nn.Module):
79101
def forward(self, input):
80102
return torch.ops.aten.upsample_bilinear2d.default(
81-
input, output_size, align_corners
103+
input, output_size, align_corners, scales_h, scales_w
82104
)
83105

84106
input = [torch.randn([1, 1] + list(input_shape))]
85107
self.run_test(Upsample(), input)
86108

87109
@parameterized.expand(
88110
[
89-
((2, 2, 2), (4, 4, 4), True),
90-
((2, 2, 2), (4, 4, 4), False),
91-
((2, 2, 2), (5, 5, 5), True),
92-
((2, 2, 2), (5, 5, 5), False),
111+
((2, 2, 2), (4, 4, 4), True, None, None, None),
112+
((2, 2, 2), (4, 4, 4), False, None, None, None),
113+
((2, 2, 2), (5, 5, 5), True, None, None, None),
114+
((2, 2, 2), (5, 5, 5), False, None, None, None),
115+
((2, 2, 2), (4, 4, 4), True, 2.0, 2.0, 2.0),
116+
((2, 2, 2), (4, 4, 4), False, 2.0, 2.0, 2.0),
117+
((2, 2, 2), (5, 5, 5), True, 2.5, 2.5, 2.5),
118+
((2, 2, 2), (5, 5, 5), False, 2.5, 2.5, 2.5),
93119
]
94120
)
95-
def test_trilinear3d(self, input_shape, output_size, align_corners):
121+
def test_trilinear3d(
122+
self, input_shape, output_size, align_corners, scales_d, scales_h, scales_w
123+
):
96124
class Upsample(torch.nn.Module):
97125
def forward(self, input):
98126
return torch.ops.aten.upsample_trilinear3d.default(
99-
input, output_size, align_corners
127+
input, output_size, align_corners, scales_d, scales_h, scales_w
100128
)
101129

102130
input = [torch.randn([1, 1] + list(input_shape))]
103131
self.run_test(Upsample(), input)
104132

105133
@parameterized.expand(
106134
[
107-
((2, 2), (4, 4), True),
108-
((2, 2), (4, 4), False),
109-
((2, 2), (5, 5), True),
110-
((2, 2), (5, 5), False),
135+
((2, 2), (4, 4), True, None, None),
136+
((2, 2), (4, 4), False, None, None),
137+
((2, 2), (5, 5), True, None, None),
138+
((2, 2), (5, 5), False, None, None),
139+
((2, 2), (4, 4), True, 2.0, 2.0),
140+
((2, 2), (4, 4), False, 2.0, 2.0),
141+
((2, 2), (5, 5), True, 2.5, 2.5),
142+
((2, 2), (5, 5), False, 2.5, 2.5),
111143
]
112144
)
113-
def test_bicubic2d(self, input_shape, output_size, align_corners):
145+
def test_bicubic2d(
146+
self, input_shape, output_size, align_corners, scales_h, scales_w
147+
):
114148
class Upsample(torch.nn.Module):
115149
def forward(self, input):
116150
return torch.ops.aten.upsample_bicubic2d.default(
117-
input, output_size, align_corners
151+
input, output_size, align_corners, scales_h, scales_w
118152
)
119153

120154
input = [torch.randn([1, 1] + list(input_shape))]

0 commit comments

Comments
 (0)