@@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
7878 }
7979}
8080
81+ bool MultiDevSSAGraphBuilder::IsDistTrainOp (const OpDesc &op,
82+ OpDesc *send_op) const {
83+ if (send_op == nullptr ) {
84+ return false ;
85+ }
86+
87+ auto checker = [&](const std::vector<std::string> opvars,
88+ const std::vector<std::string> sendvars) -> bool {
89+ bool is_dist_train_op = false ;
90+ for (auto &var : opvars) {
91+ if (var.find (" .block" ) != std::string::npos &&
92+ std::find (sendvars.begin (), sendvars.end (), var) != sendvars.end ()) {
93+ is_dist_train_op = true ;
94+ break ;
95+ }
96+ }
97+ return is_dist_train_op;
98+ };
99+
100+ if (op.Type () == " split" ) {
101+ return checker (op.OutputArgumentNames (), send_op->InputArgumentNames ());
102+ } else if (op.Type () == " concat" ) {
103+ return checker (op.InputArgumentNames (), send_op->OutputArgumentNames ());
104+ }
105+ return false ;
106+ }
107+
81108std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
82109 const ProgramDesc &program) const {
83110 auto graph = new SSAGraph ();
@@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
89116 std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
90117 places_.size ());
91118
119+ // Locate the "send" op up front, because the split ops appear before it in the block.
120+ OpDesc *send_op = nullptr ;
121+ for (auto *op : program.Block (0 ).AllOps ()) {
122+ if (op->Type () == " send" ) {
123+ send_op = op;
124+ break ;
125+ }
126+ }
127+
92128 bool is_forwarding = true ;
93129 for (auto *op : program.Block (0 ).AllOps ()) {
94130 if (op->Type () == " send" ) {
95131 // append send op if program is distributed trainer main program.
96132 // always use the first device
97133 CreateSendOp (&result, *op);
134+ } else if (IsDistTrainOp (*op, send_op)) {
135+ CreateComputationalOps (&result, *op, 1 );
98136 } else if (IsScaleLossOp (*op)) {
99137 if (!skip_scale_loss_) {
100138 CreateScaleLossGradOp (&result);
101139 }
102140 is_forwarding = false ;
103141 } else {
104- CreateComputationalOps (&result, *op);
142+ CreateComputationalOps (&result, *op, places_. size () );
105143 if (!is_forwarding) {
106144 // Currently, we assume that once gradient is generated, it can be
107145 // broadcast, and each gradient is only broadcast once. But there are no
@@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
199237}
200238
201239void MultiDevSSAGraphBuilder::CreateComputationalOps (SSAGraph *result,
202- const OpDesc &op) const {
203- for (size_t scope_idx = 0 ; scope_idx < places_.size (); ++scope_idx) {
240+ const OpDesc &op,
241+ size_t num_places) const {
242+ for (size_t scope_idx = 0 ; scope_idx < num_places; ++scope_idx) {
204243 auto p = places_[scope_idx];
205244 auto s = local_scopes_[scope_idx];
206245 result->ops_ .emplace_back (new ComputationOpHandle (op, s, p));
0 commit comments