@@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
     ScopedTensorDescriptor input_desc;
     ScopedTensorDescriptor output_desc;
     ScopedPoolingDescriptor pool_desc;
-    DataLayout layout = DataLayout::kNCHW;
+    DataLayout layout;
+
+    if (strides.size() == 2U) {
+      layout = DataLayout::kNCHW;
+    } else {
+      layout = DataLayout::kNCDHW;
+    }
 
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
@@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
     ScopedTensorDescriptor input_desc;
     ScopedTensorDescriptor output_desc;
     ScopedPoolingDescriptor pool_desc;
-    DataLayout layout = DataLayout::kNCHW;
+    DataLayout layout;
+
+    if (strides.size() == 2U) {
+      layout = DataLayout::kNCHW;
+    } else {
+      layout = DataLayout::kNCDHW;
+    }
 
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
@@ -150,5 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
-REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
+REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
+                       ops::PoolCudnnOpKernel<double>);
+REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
+                       ops::PoolCudnnGradOpKernel<double>);
+
+REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
+                       ops::PoolCudnnOpKernel<double>);
+REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
+                       ops::PoolCudnnGradOpKernel<double>);