Commit ea78be2

follow comments
1 parent 7ee07df commit ea78be2

8 files changed: +116 additions, -63 deletions


paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 15 additions & 13 deletions
@@ -53,7 +53,7 @@ void BroadcastOpHandle::RunImpl() {
 
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  if (!use_nccl_ || platform::is_cpu_place(in_tensor.place())) {
+  if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out : out_var_handles) {
       if (*out == *in_var_handle) {
         continue;
@@ -72,7 +72,7 @@ void BroadcastOpHandle::RunImpl() {
       auto dev_ctx = dev_ctxes_.at(out_p);
       RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
         paddle::framework::TensorCopy(
-            in_tensor, out_p, *(dev_ctx),
+            in_tensor, out_p, *dev_ctx,
             &VariableVisitor::GetMutableTensor(out_var));
       });
     }
@@ -81,22 +81,24 @@ void BroadcastOpHandle::RunImpl() {
     PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place()));
     VarHandle *out_handle;
     int root = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
-    std::vector<std::function<void()>> all_reduce_calls;
+    std::vector<std::function<void()>> broadcast_calls;
 
     for (size_t j = 0; j < out_var_handles.size(); ++j) {
-      auto *out = out_var_handles[j];
-      auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
+      VarHandle *out_var_handle = out_var_handles[j];
+      Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                              ->FindVar(out_var_handle->name_);
 
-      if (*out != *in_var_handle) {
+      if (*out_var_handle != *in_var_handle) {
         PADDLE_ENFORCE_NOT_NULL(out_var);
-        PADDLE_ENFORCE_EQ(out->place_.which(), in_tensor.place().which(),
+        PADDLE_ENFORCE_EQ(out_var_handle->place_.which(),
+                          in_tensor.place().which(),
                           "Places must be all on CPU or all on CUDA.");
         VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
         VariableVisitor::GetMutableTensor(out_var).mutable_data(
-            out->place_, in_tensor.type());
+            out_var_handle->place_, in_tensor.type());
       }
 
-      auto out_p = out->place_;
+      auto out_p = out_var_handle->place_;
       int dev_id = boost::get<platform::CUDAPlace>(out_p).device;
 
       auto &nccl_ctx = nccl_ctxs_->at(dev_id);
@@ -106,15 +108,15 @@ void BroadcastOpHandle::RunImpl() {
       void *send_recv_buffer = nullptr;
       if (root == dev_id) {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
-        out_handle = out;
+        out_handle = out_var_handle;
       } else {
         send_recv_buffer =
             VariableVisitor::GetMutableTensor(out_var).mutable_data(
-                out->place_);
+                out_var_handle->place_);
       }
 
       int type = platform::ToNCCLDataType(in_tensor.type());
-      all_reduce_calls.emplace_back([=] {
+      broadcast_calls.emplace_back([=] {
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             send_recv_buffer, in_tensor.numel(),
             static_cast<ncclDataType_t>(type), root, comm, stream));
@@ -124,7 +126,7 @@ void BroadcastOpHandle::RunImpl() {
     this->RunAndRecordEvent([&] {
       {
         platform::NCCLGroupGuard guard;
-        for (auto &call : all_reduce_calls) {
+        for (auto &call : broadcast_calls) {
          call();
        }
      }

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 4 additions & 8 deletions
@@ -36,12 +36,9 @@ struct BroadcastOpHandle : public OpHandleBase {
  public:
 #ifdef PADDLE_WITH_CUDA
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                    const std::vector<platform::Place> &places, bool use_nccl,
+                    const std::vector<platform::Place> &places,
                     const platform::NCCLContextMap *nccl_ctxs)
-      : local_scopes_(local_scopes),
-        places_(places),
-        use_nccl_(use_nccl),
-        nccl_ctxs_(nccl_ctxs) {
+      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
     if (nccl_ctxs_) {
       for (auto &p_ctx : nccl_ctxs_->contexts_) {
         dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
@@ -50,8 +47,8 @@ struct BroadcastOpHandle : public OpHandleBase {
   }
 #else
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                    const std::vector<platform::Place> &places, bool use_nccl)
-      : local_scopes_(local_scopes), places_(places), use_nccl_(use_nccl) {}
+                    const std::vector<platform::Place> &places)
+      : local_scopes_(local_scopes), places_(places) {}
 #endif
 
   std::string Name() const override;
@@ -65,7 +62,6 @@ struct BroadcastOpHandle : public OpHandleBase {
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
-  bool use_nccl_;
 #ifdef PADDLE_WITH_CUDA
   const platform::NCCLContextMap *nccl_ctxs_;
 #endif

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 8 additions & 8 deletions
@@ -96,19 +96,19 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
-#ifdef PADDLE_WITH_CUDA
-    op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_, use_gpu_,
-                                           nccl_ctxs_.get()));
-#endif
-
     if (use_gpu_) {
-#ifndef PADDLE_WITH_CUDA
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
-#ifndef PADDLE_WITH_CUDA
+#ifdef PADDLE_WITH_CUDA
       op_handle_.reset(
-          new BroadcastOpHandle(local_scopes_, gpu_list_, use_gpu_));
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
+      op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
 #endif
     }

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 3 deletions
@@ -239,10 +239,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                 const std::string &p_name,
                                                 size_t dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle =
-      new BroadcastOpHandle(local_scopes_, places_, true, nccl_ctxs_);
+  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, false);
+  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif
 
   result->ops_.emplace_back(op_handle);

paddle/fluid/framework/details/reduce_op_handle_test.cc

Lines changed: 8 additions & 7 deletions
@@ -96,17 +96,18 @@ struct TestReduceOpHandle {
     }
     param_scopes_[out_scope_idx]->Var("out");
 
-#ifdef PADDLE_WITH_CUDA
-    op_handle_.reset(
-        new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
-#endif
-
     if (use_gpu_) {
-#ifndef PADDLE_WITH_CUDA
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
-#ifndef PADDLE_WITH_CUDA
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
       op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_));
 #endif
     }

paddle/fluid/framework/details/var_handle.h

Lines changed: 1 addition & 4 deletions
@@ -67,10 +67,7 @@ struct VarHandle : public VarHandleBase {
            o.scope_idx_ == scope_idx_;
   }
 
-  bool operator!=(const VarHandle& o) const {
-    return o.generated_op_ != generated_op_ || o.name_ != name_ ||
-           o.scope_idx_ != scope_idx_;
-  }
+  bool operator!=(const VarHandle& o) const { return !this->operator==(o); }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators

python/paddle/fluid/parallel_executor.py

Lines changed: 9 additions & 1 deletion
@@ -44,9 +44,17 @@ def __init__(self,
                 training.
             allow_op_delay(bool, default False): Whether to delay and buffer
                 some operators together for scheduling or not, which may
-                improve performance in some cases, defalut False.
+                improve performance in some cases, default False.
             share_vars_from(ParallelExecutor, default None): If provied,
                 it will share variables from the specified ParallelExecutor.
+            use_nccl_allreduce(bool, default True): Whether to use NCCL allReduce
+                for gradient communication. If True, gradients are exchanged
+                between devices with NCCL allReduce, which does not support
+                sparse parameter updates. If False, communication is done with
+                reduce_op and broadcast_op: the parameter gradients are
+                distributed evenly across the devices, each device updates its
+                share of the parameters, and the results are broadcast back to
+                the other devices. This mode supports sparse parameter updates.
 
         Returns:
             A ParallelExecutor object.
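
To make the new flag concrete, here is a minimal usage sketch (not part of the commit): the tiny network and the use_cuda/loss_name keyword arguments are assumptions about the surrounding fluid API of this period; only use_nccl_allreduce is the parameter documented above.

import paddle.fluid as fluid

# Build a tiny regression network (illustrative only).
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Initialize parameters once with a plain Executor.
fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())

# use_nccl_allreduce=False selects the reduce_op + broadcast_op path, which
# (per the docstring above) is the mode that supports sparse parameter
# updates; use_cuda and loss_name are assumed constructor arguments.
pe = fluid.ParallelExecutor(use_cuda=True,
                            loss_name=loss.name,
                            use_nccl_allreduce=False)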
