@@ -42,7 +42,7 @@ void RunSerdeTestTensor(platform::Place place) {
4242 int tensor_numel = 4 * 8 * 4 * 2 ;
4343 platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
4444 auto & ctx = *pool.Get (place);
45- float * orig_tensor_data = tensor->mutable_data <float >(place);
45+ tensor->mutable_data <float >(place);
4646 math::set_constant (ctx, tensor, 31.9 );
4747
4848 ::grpc::ByteBuffer msg;
@@ -71,16 +71,9 @@ void RunSerdeTestTensor(platform::Place place) {
7171
7272 const float * tensor_data =
7373 reinterpret_cast <const float *>(varmsg.serialized ().data ());
74- for (int i = 0 ; i < varmsg.serialized ().size (); ++i) {
75- printf (" %02X " , varmsg.serialized ().data ()[i]);
76- }
77- printf (" \n " );
7874 for (int i = 0 ; i < tensor_numel; ++i) {
79- std::cout << " #####tensor data: " << tensor_data[i] << std::endl;
80- EXPECT_EQ (tensor_data[i], orig_tensor_data[i]);
81- std::cout << " test end 1 " << std::endl;
75+ EXPECT_FLOAT_EQ (tensor_data[i], 31.9 );
8276 }
83- std::cout << " tensor data end " << std::endl;
8477
8578 // deserialize zero-copy
8679 framework::Variable var2;
@@ -101,8 +94,7 @@ void RunSerdeTestTensor(platform::Place place) {
10194 EXPECT_EQ (varmsg.lod (0 ).lod_data (0 ), 1 );
10295 EXPECT_EQ (varmsg.lod (0 ).lod_data (1 ), 3 );
10396 EXPECT_EQ (varmsg.lod (0 ).lod_data (2 ), 8 );
104- for (int i = 0 ; i < tensor_numel; ++i)
105- EXPECT_EQ (tensor_data2[i], orig_tensor_data[i]);
97+ for (int i = 0 ; i < tensor_numel; ++i) EXPECT_FLOAT_EQ (tensor_data2[i], 31.9 );
10698}
10799
108100void RunSerdeTestSelectedRows (platform::Place place) {
@@ -114,10 +106,9 @@ void RunSerdeTestSelectedRows(platform::Place place) {
114106 auto * slr = var.GetMutable <framework::SelectedRows>();
115107 auto * tensor = slr->mutable_value ();
116108 auto * rows = slr->mutable_rows ();
117-
118109 tensor->Resize (framework::make_ddim ({2 , 10 }));
110+ tensor->mutable_data <float >(place);
119111 int tensor_numel = 2 * 10 ;
120- float * orig_tensor_data = tensor->mutable_data <float >(place);
121112 math::set_constant (ctx, tensor, 32.7 );
122113 rows->push_back (3 );
123114 rows->push_back (10 );
@@ -144,7 +135,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
144135 const int64_t * rows_data =
145136 reinterpret_cast <const int64_t *>(varmsg.rows ().data ());
146137 for (int i = 0 ; i < tensor_numel; ++i) {
147- EXPECT_EQ (tensor_data[i], orig_tensor_data[i] );
138+ EXPECT_FLOAT_EQ (tensor_data[i], 32.7 );
148139 }
149140 EXPECT_EQ (rows_data[0 ], 3 );
150141 EXPECT_EQ (rows_data[1 ], 10 );
@@ -168,21 +159,21 @@ void RunSerdeTestSelectedRows(platform::Place place) {
168159 const int64_t * rows_data2 = rows2->data ();
169160
170161 for (int i = 0 ; i < tensor_numel; ++i) {
171- EXPECT_EQ (tensor_data2[i], orig_tensor_data[i] );
162+ EXPECT_FLOAT_EQ (tensor_data2[i], 32.7 );
172163 }
173164 EXPECT_EQ (rows_data2[0 ], 3 );
174165 EXPECT_EQ (rows_data2[1 ], 10 );
175166}
176167
177- // TEST(SelectedRows, CPU) {
178- // platform::CPUPlace place;
179- // RunSerdeTestSelectedRows(place);
180- // }
168+ TEST (SelectedRows, CPU) {
169+ platform::CPUPlace place;
170+ RunSerdeTestSelectedRows (place);
171+ }
181172
182- // TEST(SelectedRows, GPU) {
183- // platform::CUDAPlace place;
184- // RunSerdeTestSelectedRows(place);
185- // }
173+ TEST (SelectedRows, GPU) {
174+ platform::CUDAPlace place;
175+ RunSerdeTestSelectedRows (place);
176+ }
186177
187178TEST (Tensor, CPU) {
188179 platform::CPUPlace place;
0 commit comments