@@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel {
2323 using framework::OperatorWithKernel::OperatorWithKernel;
2424
2525 void InferShape (framework::InferShapeContext* ctx) const override {
26- PADDLE_ENFORCE (ctx->HasInput (" Ref " ),
27- " Input(Ref ) of ScatterOp should not be null." );
28- PADDLE_ENFORCE (ctx->HasInput (" Index " ),
29- " Input(Index ) of ScatterOp should not be null." );
26+ PADDLE_ENFORCE (ctx->HasInput (" X " ),
27+ " Input(X ) of ScatterOp should not be null." );
28+ PADDLE_ENFORCE (ctx->HasInput (" Ids " ),
29+ " Input(Ids ) of ScatterOp should not be null." );
3030 PADDLE_ENFORCE (ctx->HasInput (" Updates" ),
3131 " Input(Updates) of ScatterOp should not be null." );
3232 PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
3333 " Output(Out) of ScatterOp should not be null." );
3434
3535 auto updates_dims = ctx->GetInputDim (" Updates" );
36- auto ref_dims = ctx->GetInputDim (" Ref " );
37- PADDLE_ENFORCE_EQ (ctx->GetInputDim (" Index " ).size (), 1 ,
38- " Update Index should be 1-D." );
36+ auto ref_dims = ctx->GetInputDim (" X " );
37+ PADDLE_ENFORCE_EQ (ctx->GetInputDim (" Ids " ).size (), 1 ,
38+ " Update Ids should be 1-D." );
3939 PADDLE_ENFORCE_EQ (ref_dims.size (), updates_dims.size (),
40- " Reference and Updates should have the same shape size" );
40+ " Xerence and Updates should have the same shape size" );
4141 PADDLE_ENFORCE_EQ (ctx->GetInputDim (" Updates" )[0 ],
42- ctx->GetInputDim (" Index " )[0 ],
43- " Updates and Index should have same batch-size." );
42+ ctx->GetInputDim (" Ids " )[0 ],
43+ " Updates and Ids should have same batch-size." );
4444 framework::DDim data_dim (updates_dims);
4545 for (int i = 1 ; i < data_dim.size (); ++i) {
4646 PADDLE_ENFORCE_EQ (data_dim[i], updates_dims[i]);
@@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel {
5252 framework::OpKernelType GetExpectedKernelType (
5353 const framework::ExecutionContext& ctx) const override {
5454 return framework::OpKernelType (
55- framework::ToDataType (ctx.Input <Tensor>(" Ref " )->type ()),
55+ framework::ToDataType (ctx.Input <Tensor>(" X " )->type ()),
5656 ctx.device_context ());
5757 }
5858};
@@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
6464 void InferShape (framework::InferShapeContext* ctx) const override {
6565 ctx->SetOutputDim (framework::GradVarName (" Updates" ),
6666 ctx->GetInputDim (" Updates" ));
67- ctx->SetOutputDim (framework::GradVarName (" Ref " ), ctx->GetInputDim (" Ref " ));
67+ ctx->SetOutputDim (framework::GradVarName (" X " ), ctx->GetInputDim (" X " ));
6868 }
6969
7070 protected:
7171 framework::OpKernelType GetExpectedKernelType (
7272 const framework::ExecutionContext& ctx) const override {
7373 return framework::OpKernelType (
74- framework::ToDataType (ctx.Input <Tensor>(" Ref " )->type ()),
74+ framework::ToDataType (ctx.Input <Tensor>(" X " )->type ()),
7575 ctx.device_context ());
7676 }
7777};
@@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
8080 public:
8181 ScatterOpMaker (OpProto* proto, OpAttrChecker* op_checker)
8282 : OpProtoAndCheckerMaker(proto, op_checker) {
83- AddInput (" Ref" , " The source input of scatter op" );
84- AddInput (" Index" ,
85- " The index input of scatter op where Ref will be updated" );
83+ AddInput (" X" , " The source input of scatter op" );
84+ AddInput (" Ids" , " The index input of scatter op where X will be updated" );
8685 AddInput (" Updates" , " The updated value of updates op" );
8786 AddOutput (" Out" , " The output of add op" );
8887 AddComment (R"DOC(
@@ -91,8 +90,8 @@ Scatter Operator.
9190This operator obtains output by updating the input on selected indices on the first axis:
9291
9392$$
94- Out = Ref \\
95- Out[Index ] = Ref[Index ] + Updates
93+ Out = X \\
94+ Out[Ids ] = X[Ids ] + Updates
9695$$
9796
9897)DOC" );
0 commit comments