@@ -10,11 +10,11 @@ class TestPoolConverter(DispatchTestCase):
1010 @parameterized .expand (
1111 [
1212 (3 , 1 , 0 ),
13- (3 , 1 , 1 ),
14- (2 , None , 0 ),
15- (4 , 1 , 1 ),
16- (5 , 2 , 0 ),
17- (7 , 2 , 1 ),
13+ (( 3 ,), ( 1 ,), ( 1 ,) ),
14+ (( 2 ,), [], ( 0 ,) ),
15+ (( 4 ,), ( 1 ,), ( 1 ,) ),
16+ (( 5 ,), ( 2 ,), ( 0 ,) ),
17+ (( 7 ,), ( 2 ,), ( 1 ,) ),
1818 ]
1919 )
2020 def test_avg_pool1d (
@@ -26,14 +26,10 @@ def test_avg_pool1d(
2626 count_include_pad = True ,
2727 ):
2828 class TestModule (torch .nn .Module ):
29- def __init__ (self ):
30- super ().__init__ ()
31- self .pool = torch .nn .AvgPool1d (
32- kernel_size , stride , padding , ceil_mode , count_include_pad
33- )
34-
3529 def forward (self , x ):
36- return self .pool (x )
30+ return torch .ops .aten .avg_pool1d .default (
31+ x , kernel_size , stride , padding , ceil_mode , count_include_pad
32+ )
3733
3834 inputs = [torch .randn (1 , 3 , 32 )]
3935 self .run_test (
@@ -46,7 +42,7 @@ def forward(self, x):
4642 [
4743 (3 , 1 , 0 ),
4844 (3 , 1 , 1 ),
49- ((2 , 2 ), None , (1 , 0 )),
45+ ((2 , 2 ), [] , (1 , 0 )),
5046 ((4 , 3 ), (1 , 1 ), (1 , 1 )),
5147 ((5 , 4 ), (2 , 1 ), (1 , 0 )),
5248 ((7 , 7 ), (1 , 2 ), (0 , 1 )),
@@ -62,9 +58,9 @@ def test_avg_pool2d(
6258 divisor_override = None ,
6359 ):
6460 class TestModule (torch .nn .Module ):
65- def __init__ (self ):
66- super (). __init__ ()
67- self . pool = torch . nn . AvgPool2d (
61+ def forward (self , x ):
62+ return torch . ops . aten . avg_pool2d . default (
63+ x ,
6864 kernel_size ,
6965 stride ,
7066 padding ,
@@ -73,17 +69,14 @@ def __init__(self):
7369 divisor_override ,
7470 )
7571
76- def forward (self , x ):
77- return self .pool (x )
78-
7972 inputs = [torch .randn (1 , 3 , 32 , 32 )]
8073 self .run_test (TestModule (), inputs , use_dynamo_tracer = True )
8174
8275 @parameterized .expand (
8376 [
8477 (3 , 1 , 0 ),
8578 (3 , 1 , 1 ),
86- ((2 , 2 , 3 ), None , (1 , 0 , 1 )),
79+ ((2 , 2 , 3 ), [] , (1 , 0 , 1 )),
8780 ((4 , 3 , 2 ), (1 , 1 , 1 ), (1 , 1 , 0 )),
8881 ((5 , 4 , 3 ), (2 , 1 , 2 ), (1 , 0 , 1 )),
8982 ((7 , 7 , 7 ), (1 , 2 , 1 ), (0 , 1 , 1 )),
@@ -99,9 +92,9 @@ def test_avg_pool3d(
9992 divisor_override = None ,
10093 ):
10194 class TestModule (torch .nn .Module ):
102- def __init__ (self ):
103- super (). __init__ ()
104- self . pool = torch . nn . AvgPool3d (
95+ def forward (self , x ):
96+ return torch . ops . aten . avg_pool3d . default (
97+ x ,
10598 kernel_size ,
10699 stride ,
107100 padding ,
@@ -110,20 +103,17 @@ def __init__(self):
110103 divisor_override ,
111104 )
112105
113- def forward (self , x ):
114- return self .pool (x )
115-
116106 inputs = [torch .randn (1 , 3 , 32 , 32 , 32 )]
117107 self .run_test (TestModule (), inputs , use_dynamo_tracer = True )
118108
119109 @parameterized .expand (
120110 [
121111 (3 , 1 , 0 ),
122- (3 , 1 , 1 ),
123- (2 , None , 0 ),
124- (4 , 1 , 1 ),
125- (5 , 2 , 0 ),
126- (7 , 2 , 1 ),
112+ (( 3 ,), ( 1 ,), ( 1 ,) ),
113+ (( 2 ,), [], ( 0 ,) ),
114+ (( 4 ,), ( 1 ,), ( 1 ,) ),
115+ (( 5 ,), ( 2 ,), ( 0 ,) ),
116+ (( 7 ,), ( 2 ,), ( 1 ,) ),
127117 ]
128118 )
129119 def test_max_pool1d (
@@ -132,18 +122,13 @@ def test_max_pool1d(
132122 stride ,
133123 padding ,
134124 dilation = 1 ,
135- return_indices = False ,
136125 ceil_mode = False ,
137126 ):
138127 class TestModule (torch .nn .Module ):
139- def __init__ (self ):
140- super ().__init__ ()
141- self .pool = torch .nn .MaxPool1d (
142- kernel_size , stride , padding , dilation , return_indices , ceil_mode
143- )
144-
145128 def forward (self , x ):
146- return self .pool (x )
129+ return torch .ops .aten .max_pool1d .default (
130+ x , kernel_size , stride , padding , dilation , ceil_mode
131+ )
147132
148133 inputs = [torch .randn (1 , 3 , 32 )]
149134 self .run_test (
@@ -157,7 +142,7 @@ def forward(self, x):
157142 [
158143 (3 , 1 , 0 ),
159144 (3 , 1 , 1 ),
160- ((2 , 2 ), None , (1 , 0 )),
145+ ((2 , 2 ), [] , (1 , 0 )),
161146 ((4 , 3 ), (1 , 1 ), (1 , 1 )),
162147 ((5 , 4 ), (2 , 1 ), (1 , 0 )),
163148 ((7 , 7 ), (1 , 2 ), (0 , 1 )),
@@ -169,32 +154,27 @@ def test_max_pool2d(
169154 stride ,
170155 padding ,
171156 dilation = 1 ,
172- return_indices = False ,
173157 ceil_mode = False ,
174158 ):
175159 class TestModule (torch .nn .Module ):
176- def __init__ (self ):
177- super ().__init__ ()
178- self .pool = torch .nn .MaxPool2d (
179- kernel_size ,
180- stride ,
181- padding ,
182- dilation ,
183- return_indices ,
184- ceil_mode ,
185- )
186-
187160 def forward (self , x ):
188- return self .pool (x )
161+ return torch .ops .aten .max_pool2d .default (
162+ x , kernel_size , stride , padding , dilation , ceil_mode
163+ )
189164
190165 inputs = [torch .randn (1 , 3 , 32 , 32 )]
191- self .run_test (TestModule (), inputs , use_dynamo_tracer = True , enable_passes = True )
166+ self .run_test (
167+ TestModule (),
168+ inputs ,
169+ use_dynamo_tracer = True ,
170+ enable_passes = True ,
171+ )
192172
193173 @parameterized .expand (
194174 [
195175 (3 , 1 , 0 ),
196176 (3 , 1 , 1 ),
197- ((2 , 2 , 3 ), None , (1 , 0 , 1 )),
177+ ((2 , 2 , 3 ), [] , (1 , 0 , 1 )),
198178 ((4 , 3 , 2 ), (1 , 1 , 1 ), (1 , 1 , 0 )),
199179 ((5 , 4 , 3 ), (2 , 1 , 2 ), (1 , 0 , 1 )),
200180 ((7 , 7 , 7 ), (1 , 2 , 1 ), (0 , 1 , 1 )),
@@ -206,26 +186,21 @@ def test_max_pool3d(
206186 stride ,
207187 padding ,
208188 dilation = 1 ,
209- return_indices = False ,
210189 ceil_mode = False ,
211190 ):
212191 class TestModule (torch .nn .Module ):
213- def __init__ (self ):
214- super ().__init__ ()
215- self .pool = torch .nn .MaxPool3d (
216- kernel_size ,
217- stride ,
218- padding ,
219- dilation ,
220- return_indices ,
221- ceil_mode ,
222- )
223-
224192 def forward (self , x ):
225- return self .pool (x )
193+ return torch .ops .aten .max_pool3d .default (
194+ x , kernel_size , stride , padding , dilation , ceil_mode
195+ )
226196
227197 inputs = [torch .randn (1 , 3 , 32 , 32 , 32 )]
228- self .run_test (TestModule (), inputs , use_dynamo_tracer = True , enable_passes = True )
198+ self .run_test (
199+ TestModule (),
200+ inputs ,
201+ use_dynamo_tracer = True ,
202+ enable_passes = True ,
203+ )
229204
230205
231206if __name__ == "__main__" :
0 commit comments