
Commit 246367d

merge with lod_reset branch
2 parents: 0caf2f2 + b3f076a

14 files changed, +361 -60 lines
2 files renamed without changes.

paddle/fluid/framework/channel_test.cc

Lines changed: 64 additions & 0 deletions

@@ -871,3 +871,67 @@ TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
   ch->Reset<int>(0);
   ChannelHolderDestroyUnblockSenders(ch, false);
 }
+
+// Tests that closing a ChannelHolder many times from concurrent threads is safe.
+void ChannelHolderManyTimesClose(ChannelHolder *ch) {
+  const int num_threads = 15;
+  std::thread t[num_threads];
+  bool thread_ended[num_threads];
+
+  // Launch threads that try to send data to the channel.
+  for (size_t i = 0; i < num_threads / 3; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *ended) {
+          int data = 10;
+          ch->Send(&data);
+          *ended = true;
+        },
+        &thread_ended[i]);
+  }
+
+  // Launch threads that try to receive data from the channel.
+  for (size_t i = num_threads / 3; i < 2 * num_threads / 3; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *p) {
+          int data;
+          if (ch->Receive(&data)) {
+            EXPECT_EQ(data, 10);
+          }
+          *p = true;
+        },
+        &thread_ended[i]);
+  }
+
+  // Launch threads that try to close the channel.
+  for (size_t i = 2 * num_threads / 3; i < num_threads; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *p) {
+          if (!ch->IsClosed()) {
+            ch->close();
+          }
+          *p = true;
+        },
+        &thread_ended[i]);
+  }
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait
+
+  // Verify that all threads are unblocked.
+  for (size_t i = 0; i < num_threads; i++) {
+    EXPECT_EQ(thread_ended[i], true);
+  }
+  EXPECT_TRUE(ch->IsClosed());
+  // delete the channel
+  delete ch;
+  for (size_t i = 0; i < num_threads; i++) t[i].join();
+}
+
+TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) {
+  // Check for Buffered Channel
+  ChannelHolder *ch = new ChannelHolder();
+  ch->Reset<int>(10);
+  ChannelHolderManyTimesClose(ch);
+}
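
The test hinges on two channel properties: a pending Send or Receive must unblock when the channel is closed, and racing `if (!ch->IsClosed()) ch->close();` calls must be harmless. Below is a minimal sketch of a channel with those properties — not Paddle's implementation; the unbounded buffer and the silent Send-after-close are simplifying assumptions:

    #include <condition_variable>
    #include <mutex>
    #include <queue>

    class IntChannel {
     public:
      void Send(int v) {
        std::lock_guard<std::mutex> lk(mu_);
        if (closed_) return;  // assumption: a real channel would report this failure
        buf_.push(v);
        cv_.notify_one();
      }
      bool Receive(int* v) {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [&] { return closed_ || !buf_.empty(); });
        if (buf_.empty()) return false;  // closed and drained
        *v = buf_.front();
        buf_.pop();
        return true;
      }
      bool IsClosed() {
        std::lock_guard<std::mutex> lk(mu_);
        return closed_;
      }
      void Close() {  // idempotent: safe to call any number of times
        std::lock_guard<std::mutex> lk(mu_);
        closed_ = true;
        cv_.notify_all();  // unblock every waiting receiver
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      std::queue<int> buf_;
      bool closed_ = false;
    };

Because Close() takes the same mutex as Send/Receive and notifies all waiters, any interleaving of the fifteen test threads lets every Send, Receive, and close call finish.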

paddle/fluid/framework/init.cc

Lines changed: 23 additions & 1 deletion

@@ -26,6 +26,7 @@ namespace paddle {
 namespace framework {
 
 std::once_flag gflags_init_flag;
+std::once_flag p2p_init_flag;
 
 void InitGflags(std::vector<std::string> &argv) {
   std::call_once(gflags_init_flag, [&]() {
@@ -42,6 +43,27 @@ void InitGflags(std::vector<std::string> &argv) {
   });
 }
 
+void InitP2P(int count) {
+#ifdef PADDLE_WITH_CUDA
+  std::call_once(p2p_init_flag, [&]() {
+    for (int i = 0; i < count; ++i) {
+      for (int j = 0; j < count; ++j) {
+        if (i == j) continue;
+        int can_access = -1;
+        PADDLE_ENFORCE(cudaDeviceCanAccessPeer(&can_access, i, j),
+                       "Failed to test P2P access.");
+        if (can_access != 1) {
+          LOG(WARNING) << "Cannot enable P2P access from " << i << " to " << j;
+        } else {
+          cudaSetDevice(i);
+          cudaDeviceEnablePeerAccess(j, 0);
+        }
+      }
+    }
+  });
+#endif
+}
+
 void InitDevices() {
   /* Init all available devices by default */
 
@@ -63,7 +85,7 @@ void InitDevices() {
   for (int i = 0; i < count; ++i) {
     places.emplace_back(platform::CUDAPlace(i));
  }
-
+  InitP2P(count);
   platform::DeviceContextPool::Init(places);
 }
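
For experimenting with this pass outside the framework, the same loop compiles as a standalone CUDA program; the sketch below assumes only the CUDA runtime API used in the diff, with PADDLE_ENFORCE and LOG(WARNING) reduced to a printf:

    // p2p.cu — build with: nvcc p2p.cu
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      int count = 0;
      cudaGetDeviceCount(&count);
      for (int i = 0; i < count; ++i) {
        for (int j = 0; j < count; ++j) {
          if (i == j) continue;
          int can_access = -1;
          cudaDeviceCanAccessPeer(&can_access, i, j);
          if (can_access != 1) {
            std::printf("no P2P access from %d to %d\n", i, j);
          } else {
            cudaSetDevice(i);                  // access is enabled on the current device
            cudaDeviceEnablePeerAccess(j, 0);  // second argument is reserved, must be 0
          }
        }
      }
      return 0;
    }

Peer access is directional and enabled from the current device, hence the cudaSetDevice(i) before each cudaDeviceEnablePeerAccess(j, 0).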

paddle/fluid/inference/CMakeLists.txt

Lines changed: 5 additions & 0 deletions

@@ -13,6 +13,11 @@ cc_library(paddle_fluid_shared SHARED
   SRCS io.cc
   DEPS ARCHIVE_START ${GLOB_OP_LIB} ${FLUID_CORE_MODULES} ARCHIVE_END)
 set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
+if(NOT APPLE)
+  # TODO(liuyiqun): Temporarily disable the link flag because it is not supported on Mac.
+  set(LINK_FLAGS "-Wl,--version-script ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.map")
+  set_target_properties(paddle_fluid_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
 
 if(WITH_TESTING)
   add_subdirectory(tests/book)
paddle/fluid/inference/paddle_fluid.map

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+{
+  global:
+    *paddle*;
+  local:
+    *;
+};

This linker version script exports only symbols matching *paddle* from libpaddle_fluid.so and hides everything else.

paddle/fluid/operators/lod_reset_op.cc

Lines changed: 81 additions & 31 deletions

@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    // input check
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of LoDResetOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of LoDResetOp should not be null.");
-    // If target LoD is not set form Input(), then it must be set from Attr().
-    if (!ctx->HasInput("TargetLoD")) {
+
+    if (!ctx->HasInput("Y")) {
       auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
-      PADDLE_ENFORCE(level0.size() > 1,
-                     "Target LoD is not found, should be set to be a valid one "
-                     "through Input() or Attr().");
+      PADDLE_ENFORCE_GT(level0.size(), 1,
+                        "If Input(Y) is not provided, the target lod should be "
+                        "specified by attribute `target_lod`.");
     }
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
   }
@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
-    AddInput("TargetLoD",
-             "(Tensor, optional) The target level 0 LoD from Input().")
+    AddInput("X",
+             "(Tensor, LoDTensor) Input variable of LoDResetOp which "
+             "could be a Tensor or LoDTensor, where the data of the "
+             "output variable is inherited from.");
+    AddInput("Y",
+             "(Tensor, LoDTensor, optional) If provided and Y is a LoDTensor, "
+             "the lod of Input(Y) is taken as the target lod first; "
+             "otherwise the data of Input(Y) is taken as the "
+             "target lod.")
         .AsDispensable();
-    AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
+    AddOutput("Out",
+              "(LoDTensor) Output variable of LoDResetOp which should be a "
+              "LoDTensor.");
     AddAttr<std::vector<int>>("target_lod",
                               "The target level 0 LoD from Attr().")
         .SetDefault(std::vector<int>{});
     AddComment(R"DOC(LoDReset operator
 
-Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
-Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
-Currently the lod_reset operator only supports the reset of level 0 LoD.
-At least one of Input(TargetLoD) and Attr(target_lod) must be set,
-and if both of them are set, Input(TargetLoD) will be chosen as the
-target LoD.
+Set the LoD of `X` to a new one specified by `Y` or attribute `target_lod`.
+When `Y` is provided and is a LoDTensor, `Y.lod` is taken as the target LoD
+first; otherwise `Y.data` is taken as the target LoD. If `Y` is not provided,
+the target LoD must be specified by attribute `target_lod`.
+If the target LoD is specified by `Y.data` or `target_lod`, only one level of
+LoD is supported.
+
+Example 1:
+
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
+
+attr(target_lod): [0, 4, 6]
+
+then we get a 1-level LoDTensor:
+    Out.lod  = [[ 0, 4, 6 ]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
+
+Example 2:
 
-An example:
-Given a float LoDTensor X with shape (6, 1), its transpose form represents
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
 
-[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+input(Y) is a Tensor:
+    Y.data = [[0, 2, 6]]
+    Y.dims = [1, 3]
 
-with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
+then we get a 1-level LoDTensor:
+    Out.lod  = [[ 0, 2, 6 ]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
 
-[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
+Example 3:
 
-If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
-the sequences that the LoDTensor Output(Out) contains becomes:
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
 
-[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
+input(Y) is a 2-level LoDTensor:
+    Y.lod  = [[0, 2, 4], [0, 2, 5, 6]]
+    Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
+    Y.dims = [6, 1]
+
+then we get a 2-level LoDTensor:
+    Out.lod  = [[0, 2, 4], [0, 2, 5, 6]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
 
 )DOC");
   }
@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LoDResetGradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+                   "Input(Out@Grad) of LoDResetGradOp should not be null.");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+      ctx->ShareLoD("X", /*->*/ x_grad_name);
+    }
   }
 
  protected:
@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
             ops::LoDResetGradOp);
-REGISTER_OP_CPU_KERNEL(lod_reset,
-                       ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
-                       ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
+REGISTER_OP_CPU_KERNEL(
+    lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
-    ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
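
For readers new to LoD: a level-0 LoD is just a vector of row offsets, and lod_reset regroups the same rows into new sequences without touching the data. A small standalone sketch reproducing Example 1's grouping (names are illustrative, not framework code):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> lod = {0, 4, 6};                            // attr(target_lod)
      std::vector<double> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};   // X.data, flattened
      for (std::size_t s = 0; s + 1 < lod.size(); ++s) {
        std::printf("sequence %zu:", s);
        for (int r = lod[s]; r < lod[s + 1]; ++r) std::printf(" %.1f", data[r]);
        std::printf("\n");  // prints "1.0 2.0 3.0 4.0", then "5.0 6.0"
      }
      return 0;
    }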

paddle/fluid/operators/lod_reset_op.cu

Lines changed: 6 additions & 2 deletions

@@ -18,8 +18,12 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     lod_reset, ops::LoDResetKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     lod_reset_grad,
     ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/lod_reset_op.h

Lines changed: 27 additions & 16 deletions

@@ -26,35 +26,46 @@ class LoDResetKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out = ctx.Output<framework::LoDTensor>("Out");
     auto* in = ctx.Input<framework::LoDTensor>("X");
-    auto* lod_t = ctx.Input<framework::Tensor>("TargetLoD");
+    auto* lod_t = ctx.Input<framework::LoDTensor>("Y");
+
+    out->ShareDataWith(*in);
 
     std::vector<int> level0;
     if (lod_t) {
-      auto* lod = lod_t->data<int>();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        framework::Tensor lod_cpu;
-        framework::TensorCopy(*lod_t, platform::CPUPlace(),
-                              ctx.device_context(), &lod_cpu);
-        lod = lod_cpu.data<int>();
+      if (lod_t->lod().size() > 0) {
+        auto y_lod = lod_t->lod();
+        auto last_level = y_lod[y_lod.size() - 1];
+        PADDLE_ENFORCE_EQ(last_level.back(), in->dims()[0],
+                          "Last value of `Y`'s last level LoD should be equal "
+                          "to the first dimension of `X`");
+        out->set_lod(y_lod);
+        return;  // early return, since the LoD is already set
+      } else {
+        auto* lod = lod_t->data<int>();
+        if (platform::is_gpu_place(ctx.GetPlace())) {
+          framework::Tensor lod_cpu;
+          framework::TensorCopy(*lod_t, platform::CPUPlace(),
+                                ctx.device_context(), &lod_cpu);
+          lod = lod_cpu.data<int>();
+        }
+        level0 = std::vector<int>(lod, lod + lod_t->numel());
       }
-      level0 = std::vector<int>(lod, lod + lod_t->numel());
     } else {
       level0 = ctx.Attr<std::vector<int>>("target_lod");
     }
 
-    PADDLE_ENFORCE(level0.size() > 1UL,
-                   "The size of target LoD should be greater than 1.");
-    PADDLE_ENFORCE(level0[0] == 0,
-                   "Target LoD should be a vector starting from 0.");
-    PADDLE_ENFORCE(level0.back() == in->dims()[0],
-                   "Target LoD should be a vector end with the "
-                   "first dimension of Input(X).");
+    PADDLE_ENFORCE_GT(level0.size(), 1UL,
+                      "Size of target LoD should be greater than 1.");
+    PADDLE_ENFORCE_EQ(level0[0], 0,
+                      "Target LoD should be a vector starting from 0.");
+    PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0],
+                      "Target LoD should be a vector ending with the "
+                      "first dimension of Input(X).");
     for (size_t i = 0; i < level0.size() - 1; ++i) {
       PADDLE_ENFORCE(level0[i + 1] > level0[i],
                      "Target LoD should be an ascending vector.");
     }
 
-    out->ShareDataWith(*in);
     // cast level0 to size_t
     std::vector<size_t> ulevel0(level0.size(), 0);
     std::transform(level0.begin(), level0.end(), ulevel0.begin(),
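
Restated outside the framework, the four enforcement checks above define what a valid level-0 target LoD is. A sketch with asserts standing in for PADDLE_ENFORCE (illustrative only):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // A valid level-0 target LoD has more than one offset, starts at 0,
    // is strictly ascending, and ends at the number of rows in X.
    void ValidateTargetLoD(const std::vector<int>& level0, int x_rows) {
      assert(level0.size() > 1UL);
      assert(level0.front() == 0);
      assert(level0.back() == x_rows);
      for (std::size_t i = 0; i + 1 < level0.size(); ++i) {
        assert(level0[i + 1] > level0[i]);
      }
    }

    int main() {
      ValidateTargetLoD({0, 4, 6}, 6);  // the target LoD from Example 1
      return 0;
    }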

paddle/fluid/operators/math/concat.cc

Lines changed: 2 additions & 2 deletions

@@ -44,7 +44,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
       out_cols += t_cols;
       input_cols[i] = t_cols;
     }
-    auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
 
     // computation
     for (int k = 0; k < out_rows; ++k) {
@@ -87,7 +87,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
       input_cols += t_cols;
       output_cols[i] = t_cols;
     }
-    auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
 
     // computation
    for (int k = 0; k < input_rows; ++k) {
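
The `auto&` to `auto` change here fixes a lifetime bug: `context.GetPlace()` returns the place variant by value, so `boost::get` hands back a reference into a temporary that is destroyed at the end of the full expression. A minimal sketch of the hazard, using C++17 std::variant in place of boost::variant:

    #include <cassert>
    #include <variant>

    // Stand-in for platform::Place; like DeviceContext::GetPlace(), the
    // getter below returns the variant by value.
    using Place = std::variant<int, double>;

    Place GetPlace() { return 42; }

    int main() {
      // auto& p = std::get<int>(GetPlace());  // reference into a temporary:
      //                                       // dangles (boost) or won't compile (std)
      auto p = std::get<int>(GetPlace());      // safe: copies the value out
      assert(p == 42);
      return 0;
    }

Depending on the boost version, the reference form may even compile and silently dangle, which is why copying the small CPUPlace value out is the safe choice.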
