@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
2727 VLOG (4 ) << " RunServer thread end" ;
2828}
2929
30- static void CreateTensorFromMessageType (framework::Variable *var,
31- sendrecv::VarType var_type) {
32- if (var_type == sendrecv::VarType::LOD_TENSOR) {
33- var->GetMutable <framework::LoDTensor>();
34- } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
35- var->GetMutable <framework::SelectedRows>();
36- } else {
37- PADDLE_THROW (
38- " VariableMessage type %d is not in "
39- " [LoDTensor, SelectedRows]" ,
40- var_type);
41- }
42- }
43-
4430static void ParallelExecuteBlocks (
4531 const std::vector<size_t > ¶llel_blkids, framework::Executor *executor,
4632 const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks(
6248 for (size_t i = 0 ; i < fs.size (); ++i) fs[i].wait ();
6349}
6450
51+ static void SavePort (std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
52+ std::ofstream port_file;
53+ port_file.open (" /tmp/paddle.selected_port" );
54+ port_file << rpc_service->GetSelectedPort ();
55+ port_file.close ();
56+ }
57+
6558ListenAndServOp::ListenAndServOp (const std::string &type,
6659 const framework::VariableNameMap &inputs,
6760 const framework::VariableNameMap &outputs,
@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() {
7770 server_thread_->join ();
7871}
7972
80- void ListenAndServOp::RunImpl (const framework::Scope &scope,
81- const platform::Place &dev_place) const {
82- platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
83- auto &dev_ctx = *pool.Get (dev_place);
84- framework::Scope &recv_scope = scope.NewScope ();
85-
86- if (!rpc_service_) {
87- std::string endpoint = Attr<std::string>(" endpoint" );
88- rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
89- }
90-
91- auto ins = Inputs (" X" );
73+ void ListenAndServOp::RunSyncLoop (framework::Executor *executor,
74+ framework::ProgramDesc *program,
75+ framework::Scope *recv_scope,
76+ framework::BlockDesc *prefetch_block) const {
9277 auto fan_in = Attr<int >(" Fanin" );
93- auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
94- auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
95- auto *program = optimize_block->Program ();
78+
9679 size_t num_blocks = program->Size ();
9780 PADDLE_ENFORCE_GE (num_blocks, 2 ,
9881 " server program should have at least 2 blocks" );
9982
100- framework::Executor executor (dev_place);
10183 std::vector<int > block_list;
10284 for (size_t blkid = 1 ; blkid < num_blocks; ++blkid) {
103- if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
104- block_list.push_back (blkid);
105- }
85+ block_list.push_back (blkid);
10686 }
107- auto optimize_prepared = executor. Prepare (*program, block_list);
87+ auto optimize_prepared = executor-> Prepare (*program, block_list);
10888 // Insert placeholder for block0 which holds current op itself.
10989 optimize_prepared.insert (
11090 optimize_prepared.begin (),
11191 std::shared_ptr<framework::ExecutorPrepareContext>(nullptr ));
11292
113- rpc_service_->SetScope (&recv_scope);
114- rpc_service_->SetDevCtx (&dev_ctx);
115- // TODO(qiao) set proper fields for table lookup and update
116- rpc_service_->SetExecutor (&executor);
117- VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
118- auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
119- rpc_service_->SetPrefetchBlkdId (prefetch_block->ID ());
120- rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
121- prefetch_prepared.release ();
122- rpc_service_->SetProgram (program);
123- // start the server listening after all member initialized.
124- server_thread_.reset (new std::thread (RunServer, rpc_service_));
125- VLOG (3 ) << " wait server thread to become ready..." ;
126- sleep (5 );
127- // Write to a file of server selected port for python use.
128- std::ofstream port_file;
129- port_file.open (" /tmp/paddle.selected_port" );
130- port_file << rpc_service_->GetSelectedPort ();
131- port_file.close ();
132-
13393 bool exit_flag = false ;
13494 // Record received sparse variables, so that
13595 // we could reset those after execute optimize program
@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
170130 break ;
171131 }
172132
173- // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
133+ // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
174134 // and this will still work.
175135
176136 // The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
182142 for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
183143 if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
184144 if (program->Block (blkid).Parent () != last_parent_blkid) {
185- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
186- program, & recv_scope);
145+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
146+ program, recv_scope);
187147 parallel_blkids.clear ();
188148 last_parent_blkid = program->Block (blkid).Parent ();
189149 }
190150 parallel_blkids.push_back (blkid);
191151 }
192152 }
193- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
194- program, & recv_scope);
153+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared, program ,
154+ recv_scope);
195155 VLOG (2 ) << " run all blocks spent " << detail::GetTimestamp () - ts << " (ms)" ;
196156
197157 // Reset the received sparse variables, the sum operator would not
@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
209169 } // while(true)
210170}
211171
172+ void ListenAndServOp::RunImpl (const framework::Scope &scope,
173+ const platform::Place &dev_place) const {
174+ platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
175+ auto &dev_ctx = *pool.Get (dev_place);
176+ framework::Scope &recv_scope = scope.NewScope ();
177+
178+ PADDLE_ENFORCE (!rpc_service_);
179+ std::string endpoint = Attr<std::string>(" endpoint" );
180+ rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
181+
182+ auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
183+ auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
184+ auto *program = optimize_block->Program ();
185+ framework::Executor executor (dev_place);
186+
187+ // prepare rpc_service
188+ rpc_service_->SetScope (&recv_scope);
189+ rpc_service_->SetDevCtx (&dev_ctx);
190+ rpc_service_->SetProgram (program);
191+ rpc_service_->SetExecutor (&executor);
192+
193+ // prepare for prefetch
194+ VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
195+ auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
196+ rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
197+ prefetch_prepared.release ();
198+
199+ // start the server listening after all member initialized.
200+ server_thread_.reset (new std::thread (RunServer, rpc_service_));
201+ VLOG (3 ) << " wait server thread to become ready..." ;
202+ sleep (5 );
203+ // Write to a file of server selected port for python use.
204+ SavePort (rpc_service_);
205+ RunSyncLoop (&executor, program, &recv_scope, prefetch_block);
206+ }
207+
212208class ListenAndServOpMaker : public framework ::OpProtoAndCheckerMaker {
213209 public:
214210 ListenAndServOpMaker (OpProto *proto, OpAttrChecker *op_checker)
0 commit comments