
Commit 93daa01

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/insert_reduce_to_parallel_exe
2 parents: ed052f1 + ff99d94

34 files changed: +498 −118 lines

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ RUN apt-get update && \
     automake locales clang-format swig doxygen cmake \
     liblapack-dev liblapacke-dev \
     clang-3.8 llvm-3.8 libclang-3.8-dev \
-    net-tools libtool && \
+    net-tools libtool ccache && \
     apt-get clean -y

 # Install Go and glide

paddle/cuda/include/hl_base.h

Lines changed: 18 additions & 6 deletions
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#ifndef HL_BASE_H_
-#define HL_BASE_H_
+#pragma once

 #include <cstddef>

@@ -207,8 +206,8 @@ typedef struct {

 #ifdef __NVCC__

-#include "cuda_runtime.h"
-#include "hl_cuda.h"
+#include "./cuda_runtime.h"
+#include "./hl_cuda.h"
 #include "paddle/utils/Logging.h"

 extern __thread bool g_sync_flag;
@@ -228,6 +227,19 @@ extern __thread cudaStream_t default_stream;
         << "CUDA error: " << hl_get_device_error_string((size_t)err); \
   }

-#endif /* __NVCC__ */
+// __shfl has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T
+__shfl_sync(unsigned, T val, int src_line, int width) {
+  return __shfl(val, src_line, width);
+}

-#endif /* HL_BASE_H_ */
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+#endif  // __NVCC__
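For reference, a minimal sketch (not part of this commit) of how a kernel would typically use the new CREATE_SHFL_MASK macro together with __shfl_sync: the mask is built once from a predicate (via __ballot_sync on CUDA 9.0+, or left at 0u for the pre-9.0 wrapper above) and then passed to every shuffle. The kernel name and the assumption that blockDim.x is a multiple of 32 and that `data` covers the whole grid are illustrative only.

__global__ void warp_broadcast(float *data) {
  const int warp_size = 32;
  int lane = threadIdx.x % warp_size;

  // Build the participation mask once, then reuse it for every shuffle.
  // Pre-CUDA-9.0 this expands to `mask = 0u` and the wrapper above forwards
  // the call to the legacy __shfl.
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, lane < warp_size);  // full-warp predicate here

  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  float val = data[gid];
  // Every lane reads lane 0's value of `val` within its own warp.
  val = __shfl_sync(mask, val, 0, warp_size);
  data[gid] = val;
}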

paddle/cuda/src/hl_cuda_lstm.cu

Lines changed: 9 additions & 5 deletions
@@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
 }

 __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
-  int addr = idx % 32;
+  const int warp_size = 32;
+  int addr = idx % warp_size;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, addr < warp_size);
 #pragma unroll
   for (int k = 1; k < 32; k++) {
     // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
-    addr = __shfl_sync(addr, (idx + 1) % 32, 32);
-    a[k] = __shfl_sync(a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 1) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
   }

 #pragma unroll
@@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
   }

   addr = (32 - idx) % 32;
+  CREATE_SHFL_MASK(mask, idx % 32 < warp_size);
 #pragma unroll
   for (int k = 0; k < 32; k++) {
-    a[k] = __shfl_sync(a[k], addr, 32);
-    addr = __shfl_sync(addr, (idx + 31) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 31) % 32, 32);
   }
 }

paddle/cuda/src/hl_top_k.cu

Lines changed: 4 additions & 1 deletion
@@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
     if (--beamSize == 0) break;
     __syncthreads();

+    unsigned mask = 0u;
+    // CREATE_SHFL_MASK(mask, tid < len);
+
     if (tid == maxId[0]) {
       if (beam < maxLength) {
         shTopK[tid] = topK[beam];
       }
     }
     if (maxId[0] / 32 == warp) {
-      if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
+      if (__shfl_sync(mask, beam, (maxId[0]) % 32, 32) == maxLength) break;
     }
   }
 }

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 5 additions & 5 deletions
@@ -37,7 +37,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool skip_scale_loss,
+    platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
     bool use_nccl_allreduce)
     : loss_var_name_(loss_var_name),
       places_(places),
@@ -50,7 +50,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
+    const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
     bool use_nccl_allreduce)
     : loss_var_name_(loss_var_name),
       places_(places),
@@ -60,7 +60,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
-  skip_scale_loss_ = skip_scale_loss;
+  use_default_grad_scale_ = use_default_grad_scale;
 }

 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -141,8 +141,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
-      // user can customize loss@grad if skip_scale_loss_
-      if (!skip_scale_loss_) {
+      // user can customize loss@grad if not use_default_grad_scale_
+      if (use_default_grad_scale_) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 3 deletions
@@ -36,13 +36,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
                           platform::NCCLContextMap *nccl_ctxs,
-                          bool skip_scale_loss, bool use_nccl_allreduce);
+                          bool use_default_grad_scale, bool use_nccl_allreduce);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          bool skip_scale_loss, bool use_nccl_allreduce);
+                          bool use_default_grad_scale, bool use_nccl_allreduce);
 #endif

   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -61,7 +61,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   platform::NCCLContextMap *nccl_ctxs_;
 #endif
   bool use_nccl_allreduce_;
-  bool skip_scale_loss_;
+  bool use_default_grad_scale_;

   bool IsScaleLossOp(const OpDesc &op) const;

paddle/fluid/framework/lod_tensor_test.cc

Lines changed: 2 additions & 2 deletions
@@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
     std::unique_ptr<std::istream> stream_ptr(stream);
     recordio::Scanner scanner(std::move(stream_ptr));
     auto tensors = ReadFromRecordIO(&scanner, ctx);
-    ASSERT_EQ(tensors.size(), 2);
+    ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
     assert_tensor_ok(tensors[0]);
     assert_tensor_ok(tensors[1]);
     tensors = ReadFromRecordIO(&scanner, ctx);
-    ASSERT_EQ(tensors.size(), 2);
+    ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
     assert_tensor_ok(tensors[0]);
     assert_tensor_ok(tensors[1]);
   }
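As background, a minimal illustration (not from this patch) of why the cast matters: vector::size() returns size_t while the literal 2 is a signed int, so gtest's ASSERT_EQ template comparison can trigger -Wsign-compare when warnings are treated as errors; casting the expected value makes both operands size_t.

#include <cstddef>
#include <vector>
#include "gtest/gtest.h"

TEST(SizeCompare, Example) {
  std::vector<int> v(2);
  // ASSERT_EQ(v.size(), 2);                    // size_t vs. int comparison warning
  ASSERT_EQ(v.size(), static_cast<size_t>(2));  // both operands are size_t
}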

paddle/fluid/framework/parallel_executor.cc

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &bcast_vars,
     const ProgramDesc &main_program, const std::string &loss_var_name,
     Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
-    bool customize_scale_loss, bool use_nccl_allreduce)
+    bool use_default_grad_scale, bool use_nccl_allreduce)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;

@@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      member_->nccl_ctxs_.get(), customize_scale_loss, use_nccl_allreduce);
+      member_->nccl_ctxs_.get(), use_default_grad_scale, use_nccl_allreduce);
 #else
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      customize_scale_loss, use_nccl_allreduce);
+      use_default_grad_scale, use_nccl_allreduce);
 #endif
   auto graph = builder.Build(main_program);

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ class ParallelExecutor {
                    const ProgramDesc& main_program,
                    const std::string& loss_var_name, Scope* scope,
                    const std::vector<Scope*>& local_scopes,
-                   bool allow_op_delay, bool customize_scale_loss,
+                   bool allow_op_delay, bool use_default_grad_scale,
                    bool use_nccl_allreduce);

   ~ParallelExecutor();

paddle/fluid/framework/selected_rows.cc

Lines changed: 5 additions & 5 deletions
@@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
                  : true;
 }

-std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
-                                       framework::Tensor* value) const {
+std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
+    std::vector<int64_t> keys, framework::Tensor* value) const {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
-  std::vector<int64_t> non_keys;
+  std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
   int64_t value_width = value_->numel() / value_->dims()[0];
   PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
                     "output tensor should have the same shape with table "
@@ -133,15 +133,15 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
   for (size_t i = 0; i < keys.size(); ++i) {
     int64_t index = Index(keys[i]);
     if (index == -1) {
-      non_keys.push_back(keys[i]);
+      non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
     } else {
       framework::VisitDataType(
           framework::ToDataType(value_->type()),
           TensorCopyVisitor(value, i * value_width, *value_.get(),
                             index * value_width, value_width));
     }
   }
-  return non_keys;
+  return non_keys_pair;
 }

 bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
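A hypothetical caller-side sketch (not from this commit) of how the new return value could be consumed: each pair holds the key that was not found in the table and its position in the `keys` argument, which is also the row of the output tensor that Get left unwritten. The helper name below is illustrative.

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: collect the output rows that SelectedRows::Get did
// not fill, so the caller can e.g. fetch them remotely or zero them out.
std::vector<int64_t> UnfilledRows(
    const std::vector<std::pair<int64_t, int64_t>>& non_keys_pair) {
  std::vector<int64_t> rows;
  rows.reserve(non_keys_pair.size());
  for (const auto& kv : non_keys_pair) {
    // kv.first is the missing key; kv.second is its index in `keys`,
    // i.e. the row of the output tensor that was skipped.
    rows.push_back(kv.second);
  }
  return rows;
}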
