@@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
+    auto *y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
+
     // ------------------- cudnn descriptors ---------------------
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
@@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif
 
-    VLOG(1) << "Setting descriptors.";
+    VLOG(3) << "Setting descriptors.";
     std::vector<int> dims;
     std::vector<int> strides;
     if (data_layout == DataLayout::kNCHW) {
@@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
 
-    auto *y = ctx.Output<Tensor>("Y");
-
-    // alloc memory
-    y->mutable_data<T>(ctx.GetPlace());
-
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
 
     auto handle = dev_ctx.cudnn_handle();
@@ -162,22 +160,28 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
       functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
 
-      double this_factor = 1. - momentum;
-
-      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
-          handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
-          data_desc_, x->template data<T>(), data_desc_,
-          y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<BatchNormParamType<T>>(),
-          bias->template data<BatchNormParamType<T>>(), this_factor,
-          mean_out->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          variance_out->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
-                       ctx.GetPlace()),
-          saved_variance->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace())));
+      if ((N * H * W * D) == 1) {
+        LOG(WARNING) << "Only 1 element in normalization dimension, "
+                     << "we skip the batch norm calculation, let y = x.";
+        framework::TensorCopySync(*x, ctx.GetPlace(), y);
+      } else {
+        double this_factor = 1. - momentum;
+
+        CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
+            handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
+            data_desc_, x->template data<T>(), data_desc_,
+            y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
+            scale->template data<BatchNormParamType<T>>(),
+            bias->template data<BatchNormParamType<T>>(), this_factor,
+            mean_out->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            variance_out->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                         ctx.GetPlace()),
+            saved_variance->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace())));
+      }
     }
 
     // clean when exit.
@@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
+    // init output
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    d_x->mutable_data<T>(ctx.GetPlace());
+    d_scale->mutable_data<T>(ctx.GetPlace());
+    d_bias->mutable_data<T>(ctx.GetPlace());
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    if ((N * H * W * D) == 1) {
+      framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
+      math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+          functor;
+      functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
+      functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
+      return;
+    }
+
     PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
     PADDLE_ENFORCE_EQ(scale->dims()[0], C);
 
@@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
         bn_param_desc_, data_desc_, mode_));
 
-    // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    d_scale->mutable_data<T>(ctx.GetPlace());
-    d_bias->mutable_data<T>(ctx.GetPlace());
-
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
     const void *saved_mean_data = saved_mean->template data<T>();
     const void *saved_var_data = saved_var->template data<T>();
 
-    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
         CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
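
The change above special-cases inputs whose normalization dimension (N * H * W * D) contains exactly one element: the forward pass copies x to y instead of invoking cuDNN, and the backward pass copies d_y to d_x and zeroes the scale and bias gradients. The reasoning is that a one-element batch has mean equal to the input and zero variance, so normalization degenerates. Below is a minimal standalone C++ sketch of that degenerate case; it is not Paddle's kernel API, and the batch_norm_forward helper and its parameters are hypothetical, shown only to illustrate why the kernel skips the computation.

#include <cmath>
#include <cstdio>
#include <vector>

// Per-channel batch norm over n samples. With n == 1 the batch mean equals
// the single input and the variance is 0, so (x - mean) vanishes and the
// normalized value collapses; this sketch passes the input through instead,
// mirroring the "let y = x" behavior in the commit.
std::vector<float> batch_norm_forward(const std::vector<float> &x, float scale,
                                      float bias, float epsilon) {
  const size_t n = x.size();
  std::vector<float> y(n);
  if (n == 1) {
    y[0] = x[0];  // degenerate batch: skip normalization
    return y;
  }
  float mean = 0.f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);
  float var = 0.f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = scale * (x[i] - mean) / std::sqrt(var + epsilon) + bias;
  }
  return y;
}

int main() {
  const std::vector<float> single = {3.f};
  const std::vector<float> y = batch_norm_forward(single, 1.f, 0.f, 1e-5f);
  std::printf("y[0] = %f (equals x[0] for a one-element batch)\n", y[0]);
  return 0;
}

The gradient path in the commit follows the same reasoning: with a single element there are no batch statistics to differentiate through, so d_x is simply d_y and the scale/bias gradients are set to zero.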