@@ -29,129 +29,127 @@ namespace paddle {
 namespace operators {
 namespace detail {
 
+using VarMsg = sendrecv::VariableMessage;
+
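+// Fills tensor meta data (data type, dims, LoD) into *request and points
+// *payload at the tensor data; when the tensor lives on GPU, the data is
+// first copied into a CPU buffer allocated here.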
+void GetTensorPayload(framework::Variable* var,
+                      const platform::DeviceContext& ctx, VarMsg* request,
+                      void** payload, size_t* payload_size) {
+  auto tensor = var->Get<framework::LoDTensor>();
+  // FIXME(wuyi): data types in send_recv.proto is copied from
+  // framework.proto
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  for (auto& dim : framework::vectorize(tensor.dims())) {
+    request->add_dims(dim);
+  }
+  const framework::LoD lod = tensor.lod();
+  if (lod.size() > 0) {
+    request->set_lod_level(lod.size());
+    for (auto& each : lod) {
+      VarMsg::LodData* lod_inner = request->add_lod();
+      for (auto& d : each) {
+        lod_inner->add_lod_data(d);
+      }
+    }
+  }
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
+    *payload = memory::Alloc(cpu, copy_size);
+
+    memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
+                 reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = tensor.data<void>();
+  }
+  *payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
+}
+
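+// Same as GetTensorPayload, but for SelectedRows: also records the
+// SelectedRows height and uses the underlying value tensor as the payload.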
+void GetSelectedRowsPayload(framework::Variable* var,
+                            const platform::DeviceContext& ctx, VarMsg* request,
+                            void** payload, size_t* payload_size) {
+  auto* slr = var->GetMutable<framework::SelectedRows>();
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_lod_level(0);
+  request->set_slr_height(slr->height());
+
+  for (auto& dim : framework::vectorize(slr->value().dims())) {
+    request->add_dims(dim);
+  }
+
+  auto* tensor = slr->mutable_value();
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
+    *payload = memory::Alloc(cpu, copy_size);
+    memory::Copy(cpu, *payload,
+                 boost::get<platform::CUDAPlace>(tensor->place()),
+                 reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = slr->mutable_value()->data<void>();
+  }
+  *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
+}
+
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
                            const std::string& out_name) {
-  using VarMsg = sendrecv::VariableMessage;
-  // When using GPU, need to free the copied CPU buffer
-  // when the ByteBuffer destroies
-  // TODO(typhoonzero): add unref here, if we have dependent
-  // parallelism execution, need to know when to free the tensor.
+  // The default DestroyCallback does nothing; when using GPU, the copied
+  // CPU buffer needs to be freed instead.
   DestroyCallback destroy_callback = [](void* backing) {};
-
-  auto buffer = std::unique_ptr<char[]>(new char[1024]);
-  void* buf = buffer.get();
-
+  VarMsg request;
   void* payload = nullptr;
   size_t payload_size;
-  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+
+  request.set_varname(name);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
   // trainer.
-  if (platform::ShouldSendProfileState()) {
-    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  request.set_profile(platform::IsProfileEnabled());
+  if (!out_name.empty()) {
+    request.set_out_varname(out_name);
   }
-  e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
+    request.set_type(::sendrecv::LOD_TENSOR);
+    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
   } else if (var->IsType<framework::SelectedRows>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
+    request.set_type(::sendrecv::SELECTED_ROWS);
+    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+  } else {
+    PADDLE_THROW("Serialize does not support type: %s",
+                 typeid(var->Type()).name());
   }
 
-  if (!out_name.empty()) {
-    e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    // GPU data was copied into a CPU buffer for sending; free that buffer
+    // when the ByteBuffer releases it.
+    destroy_callback = [](void* backing) {
+      platform::CPUPlace cpu;
+      memory::Free(cpu, backing);
+    };
   }
-  switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarType_Type_LOD_TENSOR: {
-      auto tensor = var->Get<framework::LoDTensor>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(tensor.type()));
-      for (auto& dim : framework::vectorize(tensor.dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      auto lod = tensor.lod();  // std::vector<Vector<size_t>>
-      if (lod.size() > 0) {
-        e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
-
-        for (auto& each : lod) {
-          e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
-                                    2 +  // tag + varintlength of submessage
-                                        1 +  // kLodDataFieldNumber
-                                        each.size());
-          // auto copied from GPU
-          for (auto& d : each) {
-            e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
-          }
-        }
-      }
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-        payload = memory::Alloc(cpu, copy_size);
-
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void*>(tensor.data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
 
-#endif
-      } else {
-        payload = tensor.data<void>();
-      }
-      payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_SELECTED_ROWS: {
-      // TODO(typhoonzero): selectedrows implement should not use unique_ptr
-      auto* slr = var->GetMutable<framework::SelectedRows>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(slr->value().type()));
-      for (auto& dim : framework::vectorize(slr->value().dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
-      e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
-      auto* tensor = slr->mutable_value();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size =
-            tensor->numel() * framework::SizeOfType(tensor->type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor->place()),
-                     reinterpret_cast<const void*>(tensor->data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = slr->mutable_value()->data<void>();
-      }
-      payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    default:
-      PADDLE_THROW("Serialize does not support type: %s",
-                   typeid(var->Type()).name());
-      break;
-  }
+  std::string header;
+  request.AppendToString(&header);
+  auto buffer = std::unique_ptr<char[]>(new char[1024]);
+  void* buf = buffer.get();
+  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  e.WriteRawBytes(std::string(header.data(), header.size()));
+  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                                     static_cast<char*>(payload)),
       ::grpc::Slice::STEAL_REF);
 
-  if (framework::ToVarType(var->Type()) ==
-      framework::proto::VarType_Type_SELECTED_ROWS) {
+  if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    // NOTE: rows is of type int64_t
     size_t rows_memory_size =
         slr->rows().size() * framework::SizeOfType(typeid(int64_t));
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
         grpc_slice_new_with_user_data(
             const_cast<void*>(
                 reinterpret_cast<const void*>(slr->rows().data())),
-            rows_memory_size,
-            [](void* backing) {
-              // TODO(typhoonzero): add unref here, same as above.
-            },
+            rows_memory_size, [](void* backing) {},
             const_cast<char*>(
                 reinterpret_cast<const char*>(slr->rows().data()))),
         ::grpc::Slice::STEAL_REF);
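
For context (not part of the diff), a minimal CPU-only usage sketch of the new serialization path. The variable name "param", the shape {4, 8}, and passing an empty out_name are illustrative assumptions; this hunk does not show whether out_name has a default argument.

// Hedged sketch: build a LoDTensor variable and serialize it to a ByteBuffer.
paddle::framework::Variable var;
auto* tensor = var.GetMutable<paddle::framework::LoDTensor>();
tensor->Resize(paddle::framework::make_ddim({4, 8}));  // illustrative shape
paddle::platform::CPUPlace place;
tensor->mutable_data<float>(place);  // allocate CPU storage

paddle::platform::CPUDeviceContext ctx(place);
::grpc::ByteBuffer msg;
// On CPU the payload slice steals a reference to the tensor's own memory,
// so `var` must outlive `msg`.
paddle::operators::detail::SerializeToByteBuffer("param", &var, ctx, &msg, "");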