Skip to content

Commit ec1e2fc

Browse files
committed
add cudnn_pool3d unit test
1 parent 7ba3d1e commit ec1e2fc

File tree

4 files changed

+106
-143
lines changed

4 files changed

+106
-143
lines changed

paddle/operators/pool_cudnn_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,4 @@ REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
155155
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
156156

157157
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>);
158-
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
158+
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);

paddle/platform/cudnn_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class ScopedTensorDescriptor {
143143
strides[i] = dims[i + 1] * strides[i + 1];
144144
}
145145
// Update tensor descriptor dims setting if groups > 1
146-
// FIXME(typhoonzero): Assume using NCHW order
146+
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
147147
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
148148
if (groups > 1) {
149149
dims_with_group[1] = dims_with_group[1] / groups;

python/paddle/v2/framework/tests/test_pool2d_op.py

Lines changed: 26 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from op_test import OpTest
44

55

6-
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
7-
6+
def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
87
N, C, H, W = x.shape
98
if global_pool == 1:
109
ksize = [H, W]
@@ -23,8 +22,7 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
2322
return out
2423

2524

26-
def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
27-
25+
def avg_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
2826
N, C, H, W = x.shape
2927
if global_pool == 1:
3028
ksize = [H, W]
@@ -47,6 +45,7 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
4745
class TestPool2d_Op(OpTest):
4846
def setUp(self):
4947
self.init_test_case()
48+
self.init_global_pool()
5049
self.init_op_type()
5150
self.init_pool_type()
5251
if self.global_pool:
@@ -75,8 +74,6 @@ def test_check_grad(self):
7574
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
7675

7776
def init_test_case(self):
78-
self.global_pool = True
79-
self.pool2D_forward_naive = avg_pool2D_forward_naive
8077
self.shape = [2, 3, 5, 5]
8178
self.ksize = [3, 3]
8279
self.strides = [1, 1]
@@ -87,12 +84,14 @@ def init_op_type(self):
8784

8885
def init_pool_type(self):
8986
self.pool_type = "avg"
87+
self.pool2D_forward_naive = avg_pool2D_forward_naive
88+
89+
def init_global_pool(self):
90+
self.global_pool = True
9091

9192

9293
class TestCase1(TestPool2d_Op):
9394
def init_test_case(self):
94-
self.global_pool = False
95-
self.pool2D_forward_naive = avg_pool2D_forward_naive
9695
self.shape = [2, 3, 7, 7]
9796
self.ksize = [3, 3]
9897
self.strides = [1, 1]
@@ -103,12 +102,14 @@ def init_op_type(self):
103102

104103
def init_pool_type(self):
105104
self.pool_type = "avg"
105+
self.pool2D_forward_naive = avg_pool2D_forward_naive
106+
107+
def init_global_pool(self):
108+
self.global_pool = False
106109

107110

108111
class TestCase2(TestPool2d_Op):
109112
def init_test_case(self):
110-
self.global_pool = False
111-
self.pool2D_forward_naive = avg_pool2D_forward_naive
112113
self.shape = [2, 3, 7, 7]
113114
self.ksize = [3, 3]
114115
self.strides = [1, 1]
@@ -119,152 +120,69 @@ def init_op_type(self):
119120

120121
def init_pool_type(self):
121122
self.pool_type = "avg"
123+
self.pool2D_forward_naive = avg_pool2D_forward_naive
122124

125+
def init_global_pool(self):
126+
self.global_pool = False
123127

124-
class TestCase3(TestPool2d_Op):
125-
def init_test_case(self):
126-
self.global_pool = True
127-
self.pool2D_forward_naive = max_pool2D_forward_naive
128-
self.shape = [2, 3, 5, 5]
129-
self.ksize = [3, 3]
130-
self.strides = [1, 1]
131-
self.paddings = [0, 0]
132128

129+
class TestCase3(TestPool2d_Op):
133130
def init_op_type(self):
134131
self.op_type = "pool2d"
135132

136133
def init_pool_type(self):
137134
self.pool_type = "max"
138-
139-
140-
class TestCase4(TestPool2d_Op):
141-
def init_test_case(self):
142-
self.global_pool = False
143135
self.pool2D_forward_naive = max_pool2D_forward_naive
144-
self.shape = [2, 3, 7, 7]
145-
self.ksize = [3, 3]
146-
self.strides = [1, 1]
147-
self.paddings = [0, 0]
148136

137+
138+
class TestCase4(TestCase1):
149139
def init_op_type(self):
150140
self.op_type = "pool2d"
151141

152142
def init_pool_type(self):
153143
self.pool_type = "max"
154-
155-
156-
class TestCase5(TestPool2d_Op):
157-
def init_test_case(self):
158-
self.global_pool = False
159144
self.pool2D_forward_naive = max_pool2D_forward_naive
160-
self.shape = [2, 3, 7, 7]
161-
self.ksize = [3, 3]
162-
self.strides = [1, 1]
163-
self.paddings = [1, 1]
164145

146+
147+
class TestCase5(TestCase2):
165148
def init_op_type(self):
166149
self.op_type = "pool2d"
167150

168151
def init_pool_type(self):
169152
self.pool_type = "max"
153+
self.pool2D_forward_naive = max_pool2D_forward_naive
170154

171155

172156
#--------------------test pool2d_cudnn--------------------
173-
class TestCaseCudnn1(TestPool2d_Op):
174-
def init_test_case(self):
175-
self.global_pool = True
176-
self.pool2D_forward_naive = avg_pool2D_forward_naive
177-
self.shape = [2, 3, 5, 5]
178-
self.ksize = [3, 3]
179-
self.strides = [1, 1]
180-
self.paddings = [0, 0]
181-
157+
class TestCudnnCase1(TestPool2d_Op):
182158
def init_op_type(self):
183159
self.op_type = "pool2d_cudnn"
184160

185-
def init_pool_type(self):
186-
self.pool_type = "avg"
187-
188-
189-
class TestCaseCudnn2(TestPool2d_Op):
190-
def init_test_case(self):
191-
self.global_pool = False
192-
self.pool2D_forward_naive = avg_pool2D_forward_naive
193-
self.shape = [2, 3, 7, 7]
194-
self.ksize = [3, 3]
195-
self.strides = [1, 1]
196-
self.paddings = [0, 0]
197161

162+
class TestCudnnCase2(TestCase1):
198163
def init_op_type(self):
199164
self.op_type = "pool2d_cudnn"
200165

201-
def init_pool_type(self):
202-
self.pool_type = "avg"
203-
204-
205-
class TestCaseCudnn3(TestPool2d_Op):
206-
def init_test_case(self):
207-
self.global_pool = False
208-
self.pool2D_forward_naive = avg_pool2D_forward_naive
209-
self.shape = [2, 3, 7, 7]
210-
self.ksize = [3, 3]
211-
self.strides = [1, 1]
212-
self.paddings = [1, 1]
213166

167+
class TestCudnnCase3(TestCase2):
214168
def init_op_type(self):
215169
self.op_type = "pool2d_cudnn"
216170

217-
def init_pool_type(self):
218-
self.pool_type = "avg"
219-
220-
221-
class TestCaseCudnn4(TestPool2d_Op):
222-
def init_test_case(self):
223-
self.global_pool = True
224-
self.pool2D_forward_naive = max_pool2D_forward_naive
225-
self.shape = [2, 3, 5, 5]
226-
self.ksize = [3, 3]
227-
self.strides = [1, 1]
228-
self.paddings = [0, 0]
229171

172+
class TestCudnnCase4(TestCase3):
230173
def init_op_type(self):
231174
self.op_type = "pool2d_cudnn"
232175

233-
def init_pool_type(self):
234-
self.pool_type = "max"
235-
236-
237-
class TestCaseCudnn5(TestPool2d_Op):
238-
def init_test_case(self):
239-
self.global_pool = False
240-
self.pool2D_forward_naive = max_pool2D_forward_naive
241-
self.shape = [2, 3, 7, 7]
242-
self.ksize = [3, 3]
243-
self.strides = [1, 1]
244-
self.paddings = [0, 0]
245176

177+
class TestCudnnCase5(TestCase4):
246178
def init_op_type(self):
247179
self.op_type = "pool2d_cudnn"
248180

249-
def init_pool_type(self):
250-
self.pool_type = "max"
251-
252-
253-
class TestCaseCudnn6(TestPool2d_Op):
254-
def init_test_case(self):
255-
self.global_pool = False
256-
self.pool2D_forward_naive = max_pool2D_forward_naive
257-
self.shape = [2, 3, 7, 7]
258-
self.ksize = [3, 3]
259-
self.strides = [1, 1]
260-
self.paddings = [1, 1]
261181

182+
class TestCudnnCase6(TestCase5):
262183
def init_op_type(self):
263184
self.op_type = "pool2d_cudnn"
264185

265-
def init_pool_type(self):
266-
self.pool_type = "max"
267-
268186

269187
if __name__ == '__main__':
270188
unittest.main()

0 commit comments

Comments
 (0)