Commit 7140071

"exported scatter to python" (#9038)

* "exported scatter to python"
* Revert ""exported scatter to python"" (reverts commit 38745a6)
* "polish scatter and export to python"

1 parent cf2addd

5 files changed: +46, -65 lines


paddle/fluid/operators/scatter_op.cc (17 additions, 18 deletions)
@@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Ref"),
-                   "Input(Ref) of ScatterOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Index"),
-                   "Input(Index) of ScatterOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ScatterOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                   "Input(Ids) of ScatterOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Updates"),
                    "Input(Updates) of ScatterOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ScatterOp should not be null.");
 
     auto updates_dims = ctx->GetInputDim("Updates");
-    auto ref_dims = ctx->GetInputDim("Ref");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1,
-                      "Update Index should be 1-D.");
+    auto ref_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
+                      "Update Ids should be 1-D.");
     PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
-                      "Reference and Updates should have the same shape size");
+                      "X and Updates should have the same shape size");
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
-                      ctx->GetInputDim("Index")[0],
-                      "Updates and Index should have same batch-size.");
+                      ctx->GetInputDim("Ids")[0],
+                      "Updates and Ids should have same batch-size.");
     framework::DDim data_dim(updates_dims);
     for (int i = 1; i < data_dim.size(); ++i) {
       PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]);
@@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
         ctx.device_context());
   }
 };
@@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("Updates"),
                       ctx->GetInputDim("Updates"));
-    ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
         ctx.device_context());
   }
 };
@@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Ref", "The source input of scatter op");
-    AddInput("Index",
-             "The index input of scatter op where Ref will be updated");
+    AddInput("X", "The source input of scatter op");
+    AddInput("Ids", "The index input of scatter op where X will be updated");
     AddInput("Updates", "The updated value of updates op");
     AddOutput("Out", "The output of add op");
     AddComment(R"DOC(
@@ -91,8 +90,8 @@ Scatter Operator.
 This operator obtains output by updating the input on selected indices on the first axis:
 
 $$
-    Out = Ref \\
-    Out[Index] = Ref[Index] + Updates
+    Out = X \\
+    Out[Ids] = X[Ids] + Updates
 $$
 
 )DOC");
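
For readers tracking the rename, here is a minimal NumPy sketch of the forward semantics under the new names (X, Ids, Updates). Note that the DOC formula above writes the update as an addition, while ScatterAssign and the unit test at the bottom of this commit overwrite the selected rows; the sketch follows the test:

import numpy as np

def scatter_forward(x, ids, updates):
    # Out = X, then the rows selected by Ids are replaced by Updates.
    out = np.copy(x)
    out[ids] = updates
    return out

x = np.random.random((3, 3)).astype("float32")
ids = np.array([1, 2]).astype("int32")
updates = np.random.random((2, 3)).astype("float32")
out = scatter_forward(x, ids, updates)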

paddle/fluid/operators/scatter_op.cu (10 additions, 10 deletions)
@@ -25,14 +25,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *Ref = ctx.Input<Tensor>("Ref");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");
 
-    Out->ShareDataWith(*Ref);
+    Out->ShareDataWith(*X);
 
-    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
   }
 };
 
@@ -42,16 +42,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    // In place gradient: dRef = dO
-    dRef->ShareDataWith(*dOut);
+    // In place gradient: dX = dO
+    dX->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
+    // Gradient by Gather: dUpdates = dO[Ids]
+    GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
   }
 };
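
The backward kernel above has two parts. A minimal NumPy sketch of the corresponding math (dX aliases dOut via ShareDataWith, and dUpdates gathers the rows of dOut selected by Ids):

import numpy as np

def scatter_backward(dout, ids):
    # In-place gradient: dX = dOut (the kernel shares the buffer rather than copying).
    dx = dout
    # Gradient by gather: dUpdates = dOut[Ids].
    dupdates = dout[ids]
    return dx, dupdates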

paddle/fluid/operators/scatter_op.h (11 additions, 11 deletions)
@@ -29,15 +29,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-    auto *Ref = ctx.Input<Tensor>("Ref");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");
 
-    // In place output: Out = Ref, Out[Index] += Updates
-    Out->ShareDataWith(*Ref);
+    // In place output: Out = X, Out[Ids] += Updates
+    Out->ShareDataWith(*X);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
   }
 };
 
@@ -47,16 +47,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    // In place gradient: dRef = dO
-    dRef->ShareDataWith(*dOut);
+    // In place gradient: dX = dO
+    dX->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates += dO[Index]
-    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
+    // Gradient by Gather: dUpdates += dO[Ids]
+    CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
   }
 };

python/paddle/fluid/layers/ops.py (7 additions, 25 deletions)
@@ -45,31 +45,13 @@
 ]
 
 __all__ = [
-    'mean',
-    'mul',
-    'reshape',
-    'scale',
-    'sigmoid_cross_entropy_with_logits',
-    'elementwise_add',
-    'elementwise_div',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_max',
-    'elementwise_min',
-    'elementwise_pow',
-    'clip',
-    'clip_by_norm',
-    'softmax',
-    'sequence_softmax',
-    'logical_and',
-    'logical_or',
-    'logical_xor',
-    'logical_not',
-    'uniform_random',
-    'uniform_random_batch_size_like',
-    'gaussian_random',
-    'gaussian_random_batch_size_like',
-    'cumsum',
+    'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
+    'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
+    'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
+    'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
+    'logical_xor', 'logical_not', 'uniform_random',
+    'uniform_random_batch_size_like', 'gaussian_random',
+    'gaussian_random_batch_size_like', 'cumsum', 'scatter'
 ] + __activations__
 
 for _OP in set(__all__):
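
With 'scatter' added to __all__, the loop above generates a fluid.layers.scatter wrapper from the op proto at import time. A usage sketch follows; the generated signature comes from generate_layer_fn, so the positional order shown here (X, then Ids, then Updates) is an assumption based on the op definition, not a documented API:

import paddle.fluid as fluid

# Shapes follow the ScatterOp checks: Ids is 1-D, and Updates has the
# same rank as X with Updates.shape[0] == Ids.shape[0].
x = fluid.layers.data(name='x', shape=[3, 3], dtype='float32', append_batch_size=False)
ids = fluid.layers.data(name='ids', shape=[2], dtype='int32', append_batch_size=False)
updates = fluid.layers.data(name='updates', shape=[2, 3], dtype='float32', append_batch_size=False)
out = fluid.layers.scatter(x, ids, updates)  # assumed argument order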

python/paddle/fluid/tests/unittests/test_scatter_op.py (1 addition, 1 deletion)
@@ -25,7 +25,7 @@ def setUp(self):
         updates_np = np.random.random((2, 3)).astype("float32")
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
-        self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
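
The hunk ends at the unchanged test_check_output context line; the gradient check that typically completes such an OpTest subclass lies outside this diff. A sketch of what it usually looks like (the checked input set and the in_place flag are assumptions, chosen to match the ShareDataWith pattern in the kernels above):

def test_check_grad(self):
    # Which inputs are gradient-checked here is assumed, not shown in the commit.
    self.check_grad(['X', 'Updates'], 'Out', in_place=True)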
