@@ -12,20 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <stdint.h>
 #include <ostream>
+#include <thread>
 
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
 
 namespace paddle {
 namespace operators {
 
-constexpr char kOptimizeBlock[] = "OptimizeBlock";
-
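+// Entry point for the server thread; blocks inside RunSyncUpdate() until the
+// server is shut down.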
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
@@ -66,143 +60,146 @@ static void ParallelExecuteBlocks(
   for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
 }
 
-class ListenAndServOp : public framework::OperatorBase {
- public:
-  ListenAndServOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {
-    if (!rpc_service_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
-      server_thread_.reset(new std::thread(RunServer, rpc_service_));
-    }
-  }
+ListenAndServOp::ListenAndServOp(const std::string &type,
+                                 const framework::VariableNameMap &inputs,
+                                 const framework::VariableNameMap &outputs,
+                                 const framework::AttributeMap &attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Stop() override {
-    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
-    server_thread_->join();
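+// Report the port the underlying gRPC service is listening on.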
+int ListenAndServOp::GetSelectedPort() {
+  return rpc_service_->GetSelectedPort();
+}
+
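+// Signal the server loop to terminate, then join the listening thread.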
+void ListenAndServOp::Stop() {
+  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+  server_thread_->join();
+}
+
+void ListenAndServOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(dev_place);
+  framework::Scope &recv_scope = scope.NewScope();
+
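+  // Construct the gRPC server lazily on the first run; the "endpoint"
+  // attribute is read only once.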
+  if (!rpc_service_) {
+    std::string endpoint = Attr<std::string>("endpoint");
+    rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
   }
 
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(dev_place);
-    framework::Scope &recv_scope = scope.NewScope();
-
-    // FIXME(Yancey1989): initialize rpc server with lazy mode.
-    rpc_service_->SetScope(&recv_scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    auto ins = Inputs("X");
-    auto fan_in = Attr<int>("Fanin");
-
-    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-    auto *program = block->Program();
-    size_t num_blocks = program->Size();
-    PADDLE_ENFORCE_GE(num_blocks, 2,
-                      "server program should have at least 2 blocks");
-
-    framework::Executor executor(dev_place);
-    std::vector<int> block_list;
-    for (size_t blkid = 1; blkid < num_blocks; ++blkid)
-      block_list.push_back(blkid);
-    auto prepared = executor.Prepare(*program, block_list);
-    prepared.insert(
-        prepared.begin(),
-        std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
-
-    // TODO(qiao) set proper fields for table lookup and update
-    rpc_service_->SetExecutor(&executor);
-    rpc_service_->SetPrefetchBlkdId(0);
-    rpc_service_->SetProgram(program);
-
-    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
-    bool exit_flag = false;
-    // Record received sparse variables, so that
-    // we could reset those after execute optimize program
-    std::vector<framework::Variable *> sparse_vars;
-    while (!exit_flag) {
-      // Get from multiple trainers, we don't care about the order in which
-      // the gradients arrives, just add suffix 0~n and merge the gradient.
-      rpc_service_->SetCond(0);
-      size_t recv_var_cnt = 0;
-      int batch_barrier = 0;
-      while (batch_barrier != fan_in) {
-        const detail::ReceivedMessage v = rpc_service_->Get();
-        auto recv_var_name = v.first;
-        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-          LOG(INFO) << "received terminate message and exit";
-          exit_flag = true;
-          break;
-        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
-          VLOG(3) << "recv batch barrier message";
-          batch_barrier++;
-          continue;
-        } else {
-          VLOG(3) << "received grad: " << recv_var_name;
-          recv_var_cnt++;
-          auto var = v.second->GetVar();
-          if (var == nullptr) {
-            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-            PADDLE_THROW("Can not find server side var");
-          }
-          if (var->IsType<framework::SelectedRows>()) {
-            sparse_vars.push_back(var);
-          }
-        }
-      }
-      if (exit_flag) {
-        rpc_service_->SetCond(1);
-        rpc_service_->ShutDown();
+  auto ins = Inputs("X");
+  auto fan_in = Attr<int>("Fanin");
+  auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *program = block->Program();
+  size_t num_blocks = program->Size();
+  PADDLE_ENFORCE_GE(num_blocks, 2,
+                    "server program should have at least 2 blocks");
+
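+  // Prepare blocks 1..N-1 once up front, so each can be executed repeatedly
+  // inside the serving loop below.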
+  framework::Executor executor(dev_place);
+  std::vector<int> block_list;
+  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
+    block_list.push_back(blkid);
+  }
+  auto prepared = executor.Prepare(*program, block_list);
+  // Insert a placeholder for block 0, which holds the current op itself.
+  prepared.insert(prepared.begin(),
+                  std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
+
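+  // Wire the scope, device context, executor and program into the RPC
+  // service before it starts serving.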
+  rpc_service_->SetScope(&recv_scope);
+  rpc_service_->SetDevCtx(&dev_ctx);
+  // TODO(qiao): set proper fields for table lookup and update
+  rpc_service_->SetExecutor(&executor);
+  rpc_service_->SetPrefetchBlkdId(0);
+  rpc_service_->SetProgram(program);
+  // Start the server listening only after all members are initialized.
+  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  // FIXME(typhoonzero): do we need to wait until the server port is ready?
+  sleep(5);
+
+  // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+  bool exit_flag = false;
+  // Record received sparse variables, so that we can reset them after
+  // the optimize program has executed.
+  std::vector<framework::Variable *> sparse_vars;
+  while (!exit_flag) {
+    // Get from multiple trainers; we don't care about the order in which
+    // the gradients arrive, just add suffix 0~n and merge the gradients.
+    rpc_service_->SetCond(0);
+    size_t recv_var_cnt = 0;
+    int batch_barrier = 0;
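+    // Collect messages until all fan_in trainers have sent a batch barrier
+    // (or a terminate message arrives).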
+    while (batch_barrier != fan_in) {
+      const detail::ReceivedMessage v = rpc_service_->Get();
+      auto recv_var_name = v.first;
+      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
+        LOG(INFO) << "received terminate message and exit";
+        exit_flag = true;
         break;
-      }
-
-      // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
-      // and this will still work.
-
-      // The optimize blocks which have the same parent ID would run parallel
-      // TODO(Yancey1989): need to use ParallelExecutor for future
-      int32_t last_parent_blkid = program->Block(1).Parent();
-      std::vector<size_t> parallel_blkids;
-      parallel_blkids.push_back(1);
-      double ts = detail::GetTimestamp();
-      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          for (size_t idx : parallel_blkids) VLOG(3) << idx;
-          ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                                &recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
+      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
+        VLOG(3) << "recv batch barrier message";
+        batch_barrier++;
+        continue;
+      } else {
+        VLOG(3) << "received grad: " << recv_var_name;
+        recv_var_cnt++;
+        auto var = v.second->GetVar();
+        if (var == nullptr) {
+          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
+          PADDLE_THROW("Can not find server side var");
+        }
+        if (var->IsType<framework::SelectedRows>()) {
+          sparse_vars.push_back(var);
         }
-        parallel_blkids.push_back(blkid);
-      }
-      ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                            &recv_scope);
-
-      VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
-              << "(ms)";
-
-      // Reset the received sparse variables, the sum operator would not
-      // sum the input sparse variables which rows is empty at the next
-      // mini-batch.
-      // TODO(Yancey1989): move the reset action into an operator, we couldn't
-      // have any hide logic in the operator.
-      for (auto &var : sparse_vars) {
-        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
       }
+    }
+    if (exit_flag) {
       rpc_service_->SetCond(1);
-      // NOTE: does not consider barrier request retry in here, we may use
-      // global barrier id to resolve this.
-      rpc_service_->WaitClientGet(fan_in);
-      sparse_vars.clear();
-    }  // while(true)
-  }
+      rpc_service_->ShutDown();
+      break;
+    }
 
- protected:
-  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
-  std::shared_ptr<std::thread> server_thread_;
-};
+    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
+    // and this will still work.
+
+    // Optimize blocks that share the same parent ID run in parallel.
+    // TODO(Yancey1989): need to use ParallelExecutor in the future
+    int32_t last_parent_blkid = program->Block(1).Parent();
+    std::vector<size_t> parallel_blkids;
+    parallel_blkids.push_back(1);
+    double ts = detail::GetTimestamp();
+    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
+      if (program->Block(blkid).Parent() != last_parent_blkid) {
+        ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
+                              &recv_scope);
+        parallel_blkids.clear();
+        last_parent_blkid = program->Block(blkid).Parent();
+      }
+      parallel_blkids.push_back(blkid);
+    }
+    ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
+                          &recv_scope);
+    VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
+
+    // Reset the received sparse variables, so that the sum operator will
+    // not sum over sparse input variables whose rows are empty at the
+    // next mini-batch.
+    // TODO(Yancey1989): move the reset action into an operator; we shouldn't
+    // have any hidden logic in the operator.
+    for (auto &var : sparse_vars) {
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    }
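+    // Release the trainers and wait until all fan_in clients have fetched
+    // the updated variables.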
+    rpc_service_->SetCond(1);
+    // FIXME(typhoonzero): use another condition to sync-wait for clients' get.
+    rpc_service_->WaitClientGet(fan_in);
+    sparse_vars.clear();
+  }  // while(true)
+}
 
 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public: