Skip to content

Commit 1cafe7b

Browse files
authored
Merge pull request #4703 from Xreki/core_optimize_backward
Simplify backward when inserting a sum operator to accumulate all duplicated variables
2 parents 62da438 + 7454ec0 commit 1cafe7b

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

paddle/framework/backward.cc

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -172,30 +172,14 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
172172
std::to_string(i));
173173
net->ops_[op_offset]->Rename(name, dup_outputs.back());
174174
}
175-
// collect all the offset to append `add` op for each alias
176-
//
177-
// one variable is shared between multiple operators.
178-
// insert add operator one by one, then add it to output
179-
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
180-
++output_idx) {
181-
auto insert_add_x = dup_outputs[output_idx];
182-
auto insert_add_y = dup_outputs[output_idx + 1];
183-
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
184-
// first add op inserted
185-
if (output_idx == dup_outputs.size() - 2) {
186-
insert_add_out = name;
187-
}
188-
if (output_idx != 0) {
189-
insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
190-
}
191-
insert_position.push_back(
192-
{dup_op.back(),
193-
OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
194-
{{"Out", {insert_add_out}}}, {})});
195-
}
175+
// collect all the offset for each alias,
176+
// insert a sum operator to add all aliases to output
177+
insert_position.push_back(
178+
{dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}},
179+
{{"Out", {name}}}, {})});
196180
}
197181

198-
// make sure the inserted `add` ops follow the BFS order.
182+
// make sure the inserted `sum` ops follow the BFS order.
199183
insert_position.sort(
200184
[](const Pos& l, const Pos& r) { return l.first > r.first; });
201185

0 commit comments

Comments
 (0)