@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
2222 using framework::OperatorWithKernel::OperatorWithKernel;
2323
2424 void InferShape (framework::InferShapeContext *ctx) const override {
25- // input check
2625 PADDLE_ENFORCE (ctx->HasInput (" X" ),
2726 " Input(X) of LoDResetOp should not be null." );
2827 PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
2928 " Output(Out) of LoDResetOp should not be null." );
30- // If target LoD is not set form Input(), then it must be set from Attr().
31- if (!ctx->HasInput (" TargetLoD " )) {
29+
30+ if (!ctx->HasInput (" Y " )) {
3231 auto level0 = ctx->Attrs ().Get <std::vector<int >>(" target_lod" );
33- PADDLE_ENFORCE (level0.size () > 1 ,
34- " Target LoD is not found, should be set to be a valid one "
35- " through Input() or Attr() ." );
32+ PADDLE_ENFORCE_GT (level0.size (), 1 ,
33+ " If Input(Y) not provided, the target lod should be "
34+ " specified by attribute `target_lod` ." );
3635 }
3736 ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
3837 }
@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
5049 public:
5150 LoDResetOpMaker (OpProto *proto, OpAttrChecker *op_checker)
5251 : OpProtoAndCheckerMaker(proto, op_checker) {
53- AddInput (" X" , " (LoDTensor) The input tensor of lod_reset operator." );
54- AddInput (" TargetLoD" ,
55- " (Tensor, optional) The target level 0 LoD from Input()." )
52+ AddInput (" X" ,
53+ " (Tensor, LoDTensor) Input variable of LoDResetOp which "
54+ " could be a Tensor or LoDTensor, where the data of output "
55+ " variable inherits from." );
56+ AddInput (" Y" ,
57+ " (Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
58+ " lod of Input(Y) would be considered as the target lod first, "
59+ " otherwise data of Input(Y) would be considered as the "
60+ " target lod." )
5661 .AsDispensable ();
57- AddOutput (" Out" , " (LoDTensor) The output tensor of lod_reset operator." );
62+ AddOutput (" Out" ,
63+ " (LoDTensor) Output variable of LoDResetOp which should be a "
64+ " LoDTensor." );
5865 AddAttr<std::vector<int >>(" target_lod" ,
5966 " The target level 0 LoD from Attr()." )
6067 .SetDefault (std::vector<int >{});
6168 AddComment (R"DOC( LoDReset operator
6269
63- Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
64- Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
65- Currently the lod_reset operator only supports the reset of level 0 LoD.
66- At least one of Input(TargetLoD) and Attr(target_lod) must be set,
67- and if both of them are set, Input(TargetLoD) will be chosen as the
68- target LoD.
70+ Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
71+ provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
72+ first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
73+ provided, target LoD should be specified by attribute `target_lod`.
74+ If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
75+ is supported.
76+
77+ Example 1:
78+
79+ Given a 1-level LoDTensor input(X):
80+ X.lod = [[ 0, 2, 5 6 ]]
81+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
82+ X.dims = [6, 1]
83+
84+ attr(target_lod): [0, 4, 6]
85+
86+ then we get a 1-level LoDTensor:
87+ Out.lod = [[ 0, 4, 6 ]]
88+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
89+ Out.dims = [6, 1]
90+
91+ Example 2:
6992
70- An example:
71- Given a float LoDTensor X with shape (6, 1), its transpose form represents
93+ Given a 1-level LoDTensor input(X):
94+ X.lod = [[ 0, 2, 5 6 ]]
95+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
96+ X.dims = [6, 1]
7297
73- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
98+ input(Y) is a Tensor:
99+ Y.data = [[0, 2, 6]]
100+ Y.dims = [1, 3]
74101
75- with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
102+ then we get a 1-level LoDTensor:
103+ Out.lod = [[ 0, 2, 6 ]]
104+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
105+ Out.dims = [6, 1]
76106
77- [1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
107+ Example 3:
78108
79- If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
80- the sequences that the LoDTensor Output(Out) contains becomes:
109+ Given a 1-level LoDTensor input(X):
110+ X.lod = [[ 0, 2, 5 6 ]]
111+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
112+ X.dims = [6, 1]
81113
82- [1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
114+ input(Y) is a 2-level LoDTensor:
115+ Y.lod = [[0, 2, 4], [0, 2, 5, 6]]
116+ Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
117+ Y.dims = [6, 1]
118+
119+ then we get a 2-level LoDTensor:
120+ Out.lod = [[0, 2, 4], [0, 2, 5, 6]]
121+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
122+ Out.dims = [6, 1]
83123
84124)DOC" );
85125 }
@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
90130 using framework::OperatorWithKernel::OperatorWithKernel;
91131
92132 void InferShape (framework::InferShapeContext *ctx) const override {
93- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) shouldn't be null." );
133+ PADDLE_ENFORCE (ctx->HasInput (" X" ),
134+ " Input(X) of LoDResetGradOp should not be null." );
94135 PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
95- " Input(Out@GRAD) shouldn't be null." );
96- ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
136+ " Input(Out@Grad) of LoDResetGradOp should not be null." );
137+
138+ auto x_grad_name = framework::GradVarName (" X" );
139+ if (ctx->HasOutput (x_grad_name)) {
140+ ctx->SetOutputDim (x_grad_name, ctx->GetInputDim (" X" ));
141+ ctx->ShareLoD (" X" , /* ->*/ x_grad_name);
142+ }
97143 }
98144
99145 protected:
@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
111157namespace ops = paddle::operators;
112158REGISTER_OP (lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
113159 ops::LoDResetGradOp);
114- REGISTER_OP_CPU_KERNEL (lod_reset,
115- ops::LoDResetKernel<paddle::platform::CPUPlace, float >,
116- ops::LoDResetKernel<paddle::platform::CPUPlace, double >);
160+ REGISTER_OP_CPU_KERNEL (
161+ lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float >,
162+ ops::LoDResetKernel<paddle::platform::CPUPlace, double >,
163+ ops::LoDResetKernel<paddle::platform::CPUPlace, int >,
164+ ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t >);
117165REGISTER_OP_CPU_KERNEL (
118166 lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float >,
119- ops::LoDResetGradKernel<paddle::platform::CPUPlace, double >);
167+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, double >,
168+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, int >,
169+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t >);
0 commit comments