Commit 2c8fe4e

Merge pull request #10143 from typhoonzero/fix_multiGPU_dist_train

Fix multi gpu dist train

2 parents e5f2cb8 + f034152

2 files changed: +46 -4 lines
paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 42 additions & 3 deletions
@@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }
 
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
+                                            OpDesc *send_op) const {
+  if (send_op == nullptr) {
+    return false;
+  }
+
+  auto checker = [&](const std::vector<std::string> opvars,
+                     const std::vector<std::string> sendvars) -> bool {
+    bool is_dist_train_op = false;
+    for (auto &var : opvars) {
+      if (var.find(".block") != std::string::npos &&
+          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+        is_dist_train_op = true;
+        break;
+      }
+    }
+    return is_dist_train_op;
+  };
+
+  if (op.Type() == "split") {
+    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  } else if (op.Type() == "concat") {
+    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+  }
+  return false;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
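The heuristic above flags an op as part of the distributed-training plumbing when one of its variables is a parameter block (its name contains ".block") that the send op also references. Below is a minimal standalone sketch of that matching rule; MatchesSendVars and the variable names are hypothetical illustrations, not part of this commit.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Standalone version of the checker lambda: returns true when some variable
// of the op is a parameter block (".block" in its name) that the send op
// also uses.
bool MatchesSendVars(const std::vector<std::string> &opvars,
                     const std::vector<std::string> &sendvars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names: a split op's outputs feeding a send op's inputs.
  std::vector<std::string> send_inputs = {"fc_0.w_0.block0", "fc_0.w_0.block1"};
  std::cout << MatchesSendVars({"fc_0.w_0.block0"}, send_inputs) << "\n";  // 1
  std::cout << MatchesSendVars({"fc_0.tmp_1"}, send_inputs) << "\n";       // 0
}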
@@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
+  // Find the "send" op first, since the "split" ops appear in front of it.
+  OpDesc *send_op = nullptr;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      send_op = op;
+      break;
+    }
+  }
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
       // Append the send op if the program is a distributed trainer main program.
       // Always use the first device.
       CreateSendOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_op)) {
+      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op);
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once a gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once. But there are no
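Taken together, the dispatch in Build now gives every op a replication count: "send" is pinned to the first device, the split/concat ops recognized by IsDistTrainOp run on a single device, and everything else is replicated across all places. The sketch below condenses that rule into one pure function; NumPlacesForOp is an illustrative name, not something this commit defines.

#include <cstddef>
#include <iostream>
#include <string>

// Condensed replication rule from Build: how many devices an op is
// instantiated on under this commit's logic.
size_t NumPlacesForOp(const std::string &op_type, bool is_dist_train_op,
                      size_t num_devices) {
  if (op_type == "send") return 1;  // handled by CreateSendOp on device 0
  if (is_dist_train_op) return 1;   // split/concat around send: one device
  return num_devices;               // ordinary ops: replicate on every device
}

int main() {
  std::cout << NumPlacesForOp("split", true, 4) << "\n";  // 1
  std::cout << NumPlacesForOp("mul", false, 4) << "\n";   // 4
}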
@@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
 }
 
 void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
-                                                     const OpDesc &op) const {
-  for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+                                                     const OpDesc &op,
+                                                     size_t num_places) const {
+  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
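The new num_places parameter moves the replication decision to the caller instead of hard-coding places_.size() inside the helper. A simplified, self-contained sketch of the generalized loop follows; the stand-in types below replace the real Place/Scope pairs and ComputationOpHandle.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Only the first num_places entries of the device list get an instance of
// the op; passing places.size() gives the old replicate-everywhere
// behavior, while passing 1 pins the op to the first device.
void CreateComputationalOpsSketch(const std::vector<std::string> &places,
                                  size_t num_places) {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    std::cout << "op placed on " << places[scope_idx] << "\n";
  }
}

int main() {
  std::vector<std::string> places = {"gpu:0", "gpu:1", "gpu:2", "gpu:3"};
  CreateComputationalOpsSketch(places, places.size());  // ordinary op
  CreateComputationalOpsSketch(places, 1);              // dist-train split/concat
}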

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 4 additions & 1 deletion
@@ -65,7 +65,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
 
-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
+  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+
+  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+                              size_t num_places) const;
 
   void CreateScaleLossGradOp(SSAGraph *result) const;
