33from op_test import OpTest
44
55
6- def max_pool2D_forward_naive (x , ksize , strides , paddings = [0 , 0 ], global_pool = 0 ):
7-
6+ def max_pool2D_forward_naive (x , ksize , strides , paddings , global_pool = 0 ):
87 N , C , H , W = x .shape
98 if global_pool == 1 :
109 ksize = [H , W ]
@@ -23,8 +22,7 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
2322 return out
2423
2524
26- def avg_pool2D_forward_naive (x , ksize , strides , paddings = [0 , 0 ], global_pool = 0 ):
27-
25+ def avg_pool2D_forward_naive (x , ksize , strides , paddings , global_pool = 0 ):
2826 N , C , H , W = x .shape
2927 if global_pool == 1 :
3028 ksize = [H , W ]
@@ -47,6 +45,7 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
4745class TestPool2d_Op (OpTest ):
4846 def setUp (self ):
4947 self .init_test_case ()
48+ self .init_global_pool ()
5049 self .init_op_type ()
5150 self .init_pool_type ()
5251 if self .global_pool :
@@ -75,8 +74,6 @@ def test_check_grad(self):
7574 self .check_grad (set (['X' ]), 'Out' , max_relative_error = 0.07 )
7675
7776 def init_test_case (self ):
78- self .global_pool = True
79- self .pool2D_forward_naive = avg_pool2D_forward_naive
8077 self .shape = [2 , 3 , 5 , 5 ]
8178 self .ksize = [3 , 3 ]
8279 self .strides = [1 , 1 ]
@@ -87,12 +84,14 @@ def init_op_type(self):
8784
8885 def init_pool_type (self ):
8986 self .pool_type = "avg"
87+ self .pool2D_forward_naive = avg_pool2D_forward_naive
88+
89+ def init_global_pool (self ):
90+ self .global_pool = True
9091
9192
9293class TestCase1 (TestPool2d_Op ):
9394 def init_test_case (self ):
94- self .global_pool = False
95- self .pool2D_forward_naive = avg_pool2D_forward_naive
9695 self .shape = [2 , 3 , 7 , 7 ]
9796 self .ksize = [3 , 3 ]
9897 self .strides = [1 , 1 ]
@@ -103,12 +102,14 @@ def init_op_type(self):
103102
104103 def init_pool_type (self ):
105104 self .pool_type = "avg"
105+ self .pool2D_forward_naive = avg_pool2D_forward_naive
106+
107+ def init_global_pool (self ):
108+ self .global_pool = False
106109
107110
108111class TestCase2 (TestPool2d_Op ):
109112 def init_test_case (self ):
110- self .global_pool = False
111- self .pool2D_forward_naive = avg_pool2D_forward_naive
112113 self .shape = [2 , 3 , 7 , 7 ]
113114 self .ksize = [3 , 3 ]
114115 self .strides = [1 , 1 ]
@@ -119,152 +120,69 @@ def init_op_type(self):
119120
120121 def init_pool_type (self ):
121122 self .pool_type = "avg"
123+ self .pool2D_forward_naive = avg_pool2D_forward_naive
122124
125+ def init_global_pool (self ):
126+ self .global_pool = False
123127
124- class TestCase3 (TestPool2d_Op ):
125- def init_test_case (self ):
126- self .global_pool = True
127- self .pool2D_forward_naive = max_pool2D_forward_naive
128- self .shape = [2 , 3 , 5 , 5 ]
129- self .ksize = [3 , 3 ]
130- self .strides = [1 , 1 ]
131- self .paddings = [0 , 0 ]
132128
129+ class TestCase3 (TestPool2d_Op ):
133130 def init_op_type (self ):
134131 self .op_type = "pool2d"
135132
136133 def init_pool_type (self ):
137134 self .pool_type = "max"
138-
139-
140- class TestCase4 (TestPool2d_Op ):
141- def init_test_case (self ):
142- self .global_pool = False
143135 self .pool2D_forward_naive = max_pool2D_forward_naive
144- self .shape = [2 , 3 , 7 , 7 ]
145- self .ksize = [3 , 3 ]
146- self .strides = [1 , 1 ]
147- self .paddings = [0 , 0 ]
148136
137+
138+ class TestCase4 (TestCase1 ):
149139 def init_op_type (self ):
150140 self .op_type = "pool2d"
151141
152142 def init_pool_type (self ):
153143 self .pool_type = "max"
154-
155-
156- class TestCase5 (TestPool2d_Op ):
157- def init_test_case (self ):
158- self .global_pool = False
159144 self .pool2D_forward_naive = max_pool2D_forward_naive
160- self .shape = [2 , 3 , 7 , 7 ]
161- self .ksize = [3 , 3 ]
162- self .strides = [1 , 1 ]
163- self .paddings = [1 , 1 ]
164145
146+
147+ class TestCase5 (TestCase2 ):
165148 def init_op_type (self ):
166149 self .op_type = "pool2d"
167150
168151 def init_pool_type (self ):
169152 self .pool_type = "max"
153+ self .pool2D_forward_naive = max_pool2D_forward_naive
170154
171155
172156#--------------------test pool2d_cudnn--------------------
173- class TestCaseCudnn1 (TestPool2d_Op ):
174- def init_test_case (self ):
175- self .global_pool = True
176- self .pool2D_forward_naive = avg_pool2D_forward_naive
177- self .shape = [2 , 3 , 5 , 5 ]
178- self .ksize = [3 , 3 ]
179- self .strides = [1 , 1 ]
180- self .paddings = [0 , 0 ]
181-
157+ class TestCudnnCase1 (TestPool2d_Op ):
182158 def init_op_type (self ):
183159 self .op_type = "pool2d_cudnn"
184160
185- def init_pool_type (self ):
186- self .pool_type = "avg"
187-
188-
189- class TestCaseCudnn2 (TestPool2d_Op ):
190- def init_test_case (self ):
191- self .global_pool = False
192- self .pool2D_forward_naive = avg_pool2D_forward_naive
193- self .shape = [2 , 3 , 7 , 7 ]
194- self .ksize = [3 , 3 ]
195- self .strides = [1 , 1 ]
196- self .paddings = [0 , 0 ]
197161
162+ class TestCudnnCase2 (TestCase1 ):
198163 def init_op_type (self ):
199164 self .op_type = "pool2d_cudnn"
200165
201- def init_pool_type (self ):
202- self .pool_type = "avg"
203-
204-
205- class TestCaseCudnn3 (TestPool2d_Op ):
206- def init_test_case (self ):
207- self .global_pool = False
208- self .pool2D_forward_naive = avg_pool2D_forward_naive
209- self .shape = [2 , 3 , 7 , 7 ]
210- self .ksize = [3 , 3 ]
211- self .strides = [1 , 1 ]
212- self .paddings = [1 , 1 ]
213166
167+ class TestCudnnCase3 (TestCase2 ):
214168 def init_op_type (self ):
215169 self .op_type = "pool2d_cudnn"
216170
217- def init_pool_type (self ):
218- self .pool_type = "avg"
219-
220-
221- class TestCaseCudnn4 (TestPool2d_Op ):
222- def init_test_case (self ):
223- self .global_pool = True
224- self .pool2D_forward_naive = max_pool2D_forward_naive
225- self .shape = [2 , 3 , 5 , 5 ]
226- self .ksize = [3 , 3 ]
227- self .strides = [1 , 1 ]
228- self .paddings = [0 , 0 ]
229171
172+ class TestCudnnCase4 (TestCase3 ):
230173 def init_op_type (self ):
231174 self .op_type = "pool2d_cudnn"
232175
233- def init_pool_type (self ):
234- self .pool_type = "max"
235-
236-
237- class TestCaseCudnn5 (TestPool2d_Op ):
238- def init_test_case (self ):
239- self .global_pool = False
240- self .pool2D_forward_naive = max_pool2D_forward_naive
241- self .shape = [2 , 3 , 7 , 7 ]
242- self .ksize = [3 , 3 ]
243- self .strides = [1 , 1 ]
244- self .paddings = [0 , 0 ]
245176
177+ class TestCudnnCase5 (TestCase4 ):
246178 def init_op_type (self ):
247179 self .op_type = "pool2d_cudnn"
248180
249- def init_pool_type (self ):
250- self .pool_type = "max"
251-
252-
253- class TestCaseCudnn6 (TestPool2d_Op ):
254- def init_test_case (self ):
255- self .global_pool = False
256- self .pool2D_forward_naive = max_pool2D_forward_naive
257- self .shape = [2 , 3 , 7 , 7 ]
258- self .ksize = [3 , 3 ]
259- self .strides = [1 , 1 ]
260- self .paddings = [1 , 1 ]
261181
182+ class TestCudnnCase6 (TestCase5 ):
262183 def init_op_type (self ):
263184 self .op_type = "pool2d_cudnn"
264185
265- def init_pool_type (self ):
266- self .pool_type = "max"
267-
268186
269187if __name__ == '__main__' :
270188 unittest .main ()
0 commit comments