@@ -53,7 +53,7 @@ void BroadcastOpHandle::RunImpl() {
 
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  if (!use_nccl_ || platform::is_cpu_place(in_tensor.place())) {
+  if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out : out_var_handles) {
       if (*out == *in_var_handle) {
         continue;
@@ -72,7 +72,7 @@ void BroadcastOpHandle::RunImpl() {
       auto dev_ctx = dev_ctxes_.at(out_p);
       RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
         paddle::framework::TensorCopy(
-            in_tensor, out_p, *(dev_ctx),
+            in_tensor, out_p, *dev_ctx,
             &VariableVisitor::GetMutableTensor(out_var));
       });
     }
@@ -81,22 +81,24 @@ void BroadcastOpHandle::RunImpl() {
     PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place()));
     VarHandle *out_handle;
     int root = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
-    std::vector<std::function<void()>> all_reduce_calls;
+    std::vector<std::function<void()>> broadcast_calls;
 
     for (size_t j = 0; j < out_var_handles.size(); ++j) {
-      auto *out = out_var_handles[j];
-      auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
+      VarHandle *out_var_handle = out_var_handles[j];
+      Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                              ->FindVar(out_var_handle->name_);
 
-      if (*out != *in_var_handle) {
+      if (*out_var_handle != *in_var_handle) {
         PADDLE_ENFORCE_NOT_NULL(out_var);
-        PADDLE_ENFORCE_EQ(out->place_.which(), in_tensor.place().which(),
+        PADDLE_ENFORCE_EQ(out_var_handle->place_.which(),
+                          in_tensor.place().which(),
                           "Places must be all on CPU or all on CUDA.");
         VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
         VariableVisitor::GetMutableTensor(out_var).mutable_data(
-            out->place_, in_tensor.type());
+            out_var_handle->place_, in_tensor.type());
       }
 
-      auto out_p = out->place_;
+      auto out_p = out_var_handle->place_;
       int dev_id = boost::get<platform::CUDAPlace>(out_p).device;
 
       auto &nccl_ctx = nccl_ctxs_->at(dev_id);
@@ -106,15 +108,15 @@ void BroadcastOpHandle::RunImpl() {
       void *send_recv_buffer = nullptr;
       if (root == dev_id) {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
-        out_handle = out;
+        out_handle = out_var_handle;
       } else {
         send_recv_buffer =
             VariableVisitor::GetMutableTensor(out_var).mutable_data(
-                out->place_);
+                out_var_handle->place_);
       }
 
       int type = platform::ToNCCLDataType(in_tensor.type());
-      all_reduce_calls.emplace_back([=] {
+      broadcast_calls.emplace_back([=] {
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             send_recv_buffer, in_tensor.numel(),
             static_cast<ncclDataType_t>(type), root, comm, stream));
@@ -124,7 +126,7 @@ void BroadcastOpHandle::RunImpl() {
     this->RunAndRecordEvent([&] {
       {
         platform::NCCLGroupGuard guard;
-        for (auto &call : all_reduce_calls) {
+        for (auto &call : broadcast_calls) {
           call();
         }
       }