Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op,
const platform::Place &p,
const size_t &i) const {
size_t place_id) const {
auto p = places_[place_id];
auto *op_handle = result->ops_.back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));

auto var_names = op.InputArgumentNames();

for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
for (auto &each_var_name : op.InputArgumentNames()) {
VarHandle *var =
CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
op_handle->AddInput(var);
}

var_names = op.OutputArgumentNames();

for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
for (auto &each_var_name : op.OutputArgumentNames()) {
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
}
}

Expand All @@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return false;
}

auto checker = [&](const std::vector<std::string> opvars,
const std::vector<std::string> sendvars) -> bool {
bool is_dist_train_op = false;
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &sendvars) -> bool {
for (auto &var : opvars) {
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
is_dist_train_op = true;
break;
return true;
}
}
return is_dist_train_op;
return false;
};

if (op.Type() == "split") {
Expand All @@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
places_.size());

// Find "send" op first for split is in front of send.
OpDesc *send_op = nullptr;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
send_op = op;
break;
}
}
OpDesc *send_op = GetSendOpDesc(program);

bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
Expand All @@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if skip_scale_loss_
if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result);
}
Expand All @@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. But there are no
// other cases, for example, we need to adjust the gradient according to
// the input when we get the gradient, which is not considered at
// present.
// broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
InsertNCCLAllReduceOp(&result, og);
Expand Down Expand Up @@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph);
}

OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}

void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
Expand Down Expand Up @@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
CreateOpHandleIOs(result, op, p, scope_idx);
CreateOpHandleIOs(result, op, scope_idx);
}
}

Expand All @@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(result, op, p, 0);
CreateOpHandleIOs(result, op, 0);
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
Expand Down
11 changes: 10 additions & 1 deletion paddle/fluid/framework/details/multi_devices_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
const platform::Place &p, const size_t &i) const;
size_t place_id) const;

private:
std::string loss_var_name_;
Expand All @@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

void CreateSendOp(SSAGraph *result, const OpDesc &op) const;

/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;

void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
Expand All @@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unordered_set<std::string> *og_has_been_broadcast) const;

void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;

/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
};
} // namespace details
} // namespace framework
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/details/ssa_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,22 @@ namespace paddle {
namespace framework {
namespace details {

// A SSA graph used by parallel executor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a link to the Wikipedia page to explain what SSA is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be added in next PR.

struct SSAGraph {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a different version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;

// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

// all operators. NOTE that even we use a vector here, the operators is
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/ssa_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class SSAGraphBuilder {
const platform::Place &place,
size_t place_offset);

// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
Expand Down