Skip to content

Commit 0b1052f

Browse files
committed
Get DEFINE_OPERATOR_CTOR Back to code
1 parent 509d320 commit 0b1052f

19 files changed

+58
-0
lines changed

paddle/framework/backward_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using DeviceContext = platform::DeviceContext;
3030

3131
class EmptyOp : public OperatorBase {
3232
public:
33+
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
3334
void InferShape(const Scope &scope) const override {}
3435
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
3536
};
@@ -78,6 +79,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
7879

7980
class FcOp : public operators::NetOp {
8081
public:
82+
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
8183
void Init() override {
8284
AddOp(OpRegistry::CreateOp("mul",
8385
{{"X", {Input("X")}}, {"Y", {Input("W")}}},

paddle/framework/grad_op_builder_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace framework {
1010

1111
class NOP : public OperatorBase {
1212
public:
13+
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
1314
void InferShape(const Scope &scope) const override {}
1415
void Run(const Scope &scope,
1516
const platform::DeviceContext &dev_ctx) const override {}

paddle/framework/op_registry_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace paddle {
77
namespace framework {
88
class CosineOp : public OperatorBase {
99
public:
10+
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
1011
void Run(const Scope& scope,
1112
const platform::DeviceContext& dev_ctx) const override {}
1213
void InferShape(const Scope& scope) const override {}
@@ -27,6 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
2728

2829
class MyTestOp : public OperatorBase {
2930
public:
31+
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
3032
void InferShape(const Scope& scope) const override {}
3133
void Run(const Scope& scope,
3234
const platform::DeviceContext& dev_ctx) const override {}

paddle/framework/operator.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ class ExecutionContext;
6464
*/
6565
class OperatorBase {
6666
public:
67+
using VarNameMap = std::map<std::string, std::vector<std::string>>;
68+
69+
OperatorBase() = default;
70+
OperatorBase(const std::string& type, const VarNameMap& inputs,
71+
const VarNameMap& outputs, const AttributeMap& attrs)
72+
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
73+
74+
OperatorBase(const OperatorBase& o) = delete;
75+
OperatorBase& operator=(const OperatorBase& o) = delete;
76+
OperatorBase(OperatorBase&& o) = delete;
77+
6778
virtual ~OperatorBase() {}
6879

6980
template <typename T>
@@ -151,6 +162,15 @@ class OperatorBase {
151162
AttributeMap attrs_;
152163
};
153164

165+
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
166+
public: \
167+
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
168+
} \
169+
Class(const std::string& type, const VarNameMap& inputs, \
170+
const VarNameMap& outputs, \
171+
const paddle::framework::AttributeMap& attrs) \
172+
: ParentClass(type, inputs, outputs, attrs) {}
173+
154174
class InferShapeContext {
155175
public:
156176
InferShapeContext(const OperatorBase& op, const Scope& scope)
@@ -290,6 +310,8 @@ class OpKernel {
290310

291311
class OperatorWithKernel : public OperatorBase {
292312
public:
313+
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
314+
293315
struct OpKernelKey {
294316
platform::Place place_;
295317

paddle/framework/operator_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ namespace framework {
2222
static int op_run_num = 0;
2323

2424
class OpWithoutKernelTest : public OperatorBase {
25+
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
26+
2527
public:
2628
void Init() override { x = 1; }
2729
void InferShape(const Scope& scope) const override {}
@@ -102,6 +104,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
102104
static int cpu_kernel_run_num = 0;
103105

104106
class OpWithKernelTest : public OperatorWithKernel {
107+
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
105108
protected:
106109
void InferShape(const framework::InferShapeContext& ctx) const override {}
107110
};

paddle/operators/add_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ namespace paddle {
1818
namespace operators {
1919

2020
class AddOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
22+
2123
protected:
2224
void InferShape(const framework::InferShapeContext &ctx) const override {
2325
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
@@ -43,6 +45,7 @@ The equation is: Out = X + Y
4345
};
4446

4547
class AddOpGrad : public framework::OperatorWithKernel {
48+
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
4649
protected:
4750
void InferShape(const framework::InferShapeContext &ctx) const override {}
4851
};

paddle/operators/cross_entropy_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
auto *X = ctx.Input<Tensor>("X");
@@ -31,6 +32,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
3132
};
3233

3334
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
35+
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
36+
framework::OperatorWithKernel)
3437
protected:
3538
void InferShape(const framework::InferShapeContext &ctx) const override {
3639
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

paddle/operators/fill_zeros_like_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ namespace paddle {
1818
namespace operators {
1919

2020
class FillZerosLikeOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
22+
2123
protected:
2224
void InferShape(const framework::InferShapeContext &ctx) const override {
2325
ctx.Output<framework::Tensor>("Dst")->Resize(

paddle/operators/gaussian_random_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
4343
};
4444

4545
class GaussianRandomOp : public framework::OperatorWithKernel {
46+
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
47+
4648
protected:
4749
void InferShape(const framework::InferShapeContext& context) const override {
4850
auto* tensor = context.Output<framework::Tensor>(0);

paddle/operators/mean_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class MeanOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
@@ -37,6 +38,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
3738
};
3839

3940
class MeanGradOp : public framework::OperatorWithKernel {
41+
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
4042
protected:
4143
void InferShape(const framework::InferShapeContext &ctx) const override {
4244
ctx.Output<Tensor>(framework::GradVarName("X"))

0 commit comments

Comments
 (0)