Skip to content

Commit 11c3560

Browse files
committed
Remove empty constructor for operator
1 parent 0b1052f commit 11c3560

24 files changed

+158
-116
lines changed

paddle/framework/backward_test.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
3030

3131
class EmptyOp : public OperatorBase {
3232
public:
33-
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
33+
using OperatorBase::OperatorBase;
3434
void InferShape(const Scope &scope) const override {}
3535
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
3636
};
@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
7979

8080
class FcOp : public operators::NetOp {
8181
public:
82-
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
83-
void Init() override {
82+
FcOp(const std::string &type, const VarNameMap &inputs,
83+
const VarNameMap &outputs, const AttributeMap &attrs)
84+
: NetOp(type, inputs, outputs, attrs) {
8485
AddOp(OpRegistry::CreateOp("mul",
8586
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
8687
{{"Out", {Output("mul_result")}}}, {}));

paddle/framework/grad_op_builder.cc

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@ class OpRegistry;
2323

2424
enum class OpArgType { IN, OUT };
2525

26-
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
27-
const OpArgType& src_type, const OpArgType& dst_type,
28-
bool is_grad) {
26+
static void TransOpArg(const OperatorBase* src_op,
27+
OperatorBase::VarNameMap* vars,
28+
const OpArgType& src_type, bool is_grad) {
2929
const auto& src_inout =
3030
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
31-
auto& dst_inout =
32-
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
31+
auto& dst_inout = *vars;
3332

3433
const OpProto& proto = OpProtos().at(src_op->type_);
3534
const auto& src_arg_list =
@@ -47,15 +46,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
4746
}
4847

4948
OperatorBase* BuildGradOp(const OperatorBase* op) {
50-
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
51-
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
52-
grad_op->type_ = grad_op_type;
53-
grad_op->attrs_ = op->attrs_;
54-
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I
55-
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O
56-
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG
57-
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG
58-
return grad_op;
49+
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
50+
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
51+
"Operator %s do not register gradient type", op->type_);
52+
auto& grad_op_type = gop_type_it->second;
53+
OperatorBase::VarNameMap inputs;
54+
OperatorBase::VarNameMap outputs;
55+
TransOpArg(op, &inputs, OpArgType::IN, false); // I
56+
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
57+
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
58+
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
59+
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
60+
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
61+
"Operator %s 's Gradient %s's creator cannot be found",
62+
op->type_, grad_op_type);
63+
64+
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
5965
}
6066

6167
} // namespace framework

paddle/framework/grad_op_builder_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace framework {
1010

1111
class NOP : public OperatorBase {
1212
public:
13-
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
13+
using OperatorBase::OperatorBase;
1414
void InferShape(const Scope &scope) const override {}
1515
void Run(const Scope &scope,
1616
const platform::DeviceContext &dev_ctx) const override {}

paddle/framework/op_registry.h

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,19 @@ class OpProtoAndCheckerMaker {
117117
};
118118

119119
class OpRegistry {
120-
using OpCreator = std::function<OperatorBase*()>;
121-
using VarNameMap = std::map<std::string, std::vector<std::string>>;
120+
using VarNameMap = OperatorBase::VarNameMap;
121+
using OpCreator = std::function<OperatorBase*(
122+
const std::string& /*type*/, const VarNameMap& /*inputs*/,
123+
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
122124

123125
public:
124126
template <typename OpType, typename ProtoMakerType>
125127
static void RegisterOp(const std::string& op_type) {
126-
op_creators()[op_type] = [] { return new OpType; };
128+
op_creators()[op_type] = [](
129+
const std::string& type, const VarNameMap& inputs,
130+
const VarNameMap& outputs, const AttributeMap& attrs) {
131+
return new OpType(type, inputs, outputs, attrs);
132+
};
127133
OpAttrChecker& op_checker = op_checkers()[op_type];
128134
OpProto& op_proto = OpProtos()[op_type];
129135
auto maker = ProtoMakerType(&op_proto, &op_checker);
@@ -138,29 +144,25 @@ class OpRegistry {
138144
template <typename GradOpType>
139145
static void RegisterGradOp(const std::string& op_type,
140146
const std::string& grad_op_type) {
141-
op_creators()[grad_op_type] = [] { return new GradOpType; };
147+
op_creators()[grad_op_type] = [](
148+
const std::string& type, const VarNameMap& inputs,
149+
const VarNameMap& outputs, const AttributeMap& attrs) {
150+
return new GradOpType(type, inputs, outputs, attrs);
151+
};
142152
grad_ops()[op_type] = grad_op_type;
143153
}
144154

145155
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
146156
const VarNameMap& inputs,
147157
const VarNameMap& outputs,
148-
const AttributeMap& attrs) {
158+
AttributeMap attrs) {
149159
auto op_create_it = op_creators().find(type);
150160
PADDLE_ENFORCE(op_create_it != op_creators().end(),
151161
"Operator %s cannot be found.", type);
162+
op_checkers().at(type).Check(attrs);
152163

153-
auto op = op_create_it->second();
154-
op->type_ = type;
155-
op->inputs_ = inputs;
156-
op->outputs_ = outputs;
157-
158-
op->attrs_ = attrs;
159-
op_checkers().at(type).Check(op->attrs_);
160-
161-
GenerateTempVariableName(op);
164+
auto op = op_create_it->second(type, inputs, outputs, attrs);
162165

163-
op->Init();
164166
return std::shared_ptr<OperatorBase>(op);
165167
}
166168

@@ -195,7 +197,6 @@ class OpRegistry {
195197
PADDLE_ENFORCE(!op.IsNetOp(),
196198
"Use framework::Backward to get backward ops");
197199
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
198-
grad_op->Init();
199200
return grad_op;
200201
}
201202

@@ -214,19 +215,6 @@ class OpRegistry {
214215
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
215216
return op_checkers_;
216217
}
217-
218-
static void GenerateTempVariableName(OperatorBase* op) {
219-
static std::atomic<size_t> gUniqId(0UL);
220-
for (auto& output : op->outputs_) {
221-
for (auto& output_name : output.second) {
222-
if (output_name == kTempVarName) {
223-
output_name += op->type_;
224-
output_name += "@";
225-
output_name += std::to_string(gUniqId.fetch_add(1));
226-
}
227-
}
228-
}
229-
}
230218
};
231219

232220
class Registrar {

paddle/framework/op_registry_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace paddle {
77
namespace framework {
88
class CosineOp : public OperatorBase {
99
public:
10-
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
10+
using OperatorBase::OperatorBase;
1111
void Run(const Scope& scope,
1212
const platform::DeviceContext& dev_ctx) const override {}
1313
void InferShape(const Scope& scope) const override {}
@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
2828

2929
class MyTestOp : public OperatorBase {
3030
public:
31-
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
31+
using OperatorBase::OperatorBase;
3232
void InferShape(const Scope& scope) const override {}
3333
void Run(const Scope& scope,
3434
const platform::DeviceContext& dev_ctx) const override {}

paddle/framework/operator.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,21 @@ void OperatorBase::Rename(const std::string& old_name,
120120
}
121121
}
122122

123+
OperatorBase::OperatorBase(const std::string& type,
124+
const OperatorBase::VarNameMap& inputs,
125+
const OperatorBase::VarNameMap& outputs,
126+
const AttributeMap& attrs)
127+
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
128+
static std::atomic<size_t> gUniqId(0UL);
129+
for (auto& output : outputs_) {
130+
for (auto& output_name : output.second) {
131+
if (output_name == kTempVarName) {
132+
output_name += type_;
133+
output_name += "@";
134+
output_name += std::to_string(gUniqId.fetch_add(1));
135+
}
136+
}
137+
}
138+
}
123139
} // namespace framework
124140
} // namespace paddle

paddle/framework/operator.h

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,8 @@ class OperatorBase {
6666
public:
6767
using VarNameMap = std::map<std::string, std::vector<std::string>>;
6868

69-
OperatorBase() = default;
7069
OperatorBase(const std::string& type, const VarNameMap& inputs,
71-
const VarNameMap& outputs, const AttributeMap& attrs)
72-
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
70+
const VarNameMap& outputs, const AttributeMap& attrs);
7371

7472
OperatorBase(const OperatorBase& o) = delete;
7573
OperatorBase& operator=(const OperatorBase& o) = delete;
@@ -86,10 +84,6 @@ class OperatorBase {
8684

8785
virtual std::string DebugString() const;
8886

89-
/// Init will be called after CreateOperator, you can put some initialization
90-
/// logic here.
91-
virtual void Init() {}
92-
9387
/// InferShape infer the size of Variables used by this Operator with
9488
/// information inside scope
9589
virtual void InferShape(const Scope& scope) const = 0;
@@ -154,23 +148,14 @@ class OperatorBase {
154148
// I (Inputs)
155149
// O (Outputs)
156150
// OG (Output Gradients)
157-
std::map<std::string, std::vector<std::string>> inputs_;
151+
VarNameMap inputs_;
158152

159153
// NOTE: in case of OpGrad, outputs_ contains
160154
// IG (Inputs Gradients)
161-
std::map<std::string, std::vector<std::string>> outputs_;
155+
VarNameMap outputs_;
162156
AttributeMap attrs_;
163157
};
164158

165-
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
166-
public: \
167-
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
168-
} \
169-
Class(const std::string& type, const VarNameMap& inputs, \
170-
const VarNameMap& outputs, \
171-
const paddle::framework::AttributeMap& attrs) \
172-
: ParentClass(type, inputs, outputs, attrs) {}
173-
174159
class InferShapeContext {
175160
public:
176161
InferShapeContext(const OperatorBase& op, const Scope& scope)
@@ -310,8 +295,6 @@ class OpKernel {
310295

311296
class OperatorWithKernel : public OperatorBase {
312297
public:
313-
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
314-
315298
struct OpKernelKey {
316299
platform::Place place_;
317300

@@ -335,6 +318,10 @@ class OperatorWithKernel : public OperatorBase {
335318
using OpKernelMap =
336319
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
337320

321+
OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
322+
const VarNameMap& outputs, const AttributeMap& attrs)
323+
: OperatorBase(type, inputs, outputs, attrs) {}
324+
338325
void InferShape(const Scope& scope) const override {
339326
InferShape(InferShapeContext(*this, scope));
340327
}

paddle/framework/operator_test.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace framework {
2222
static int op_run_num = 0;
2323

2424
class OpWithoutKernelTest : public OperatorBase {
25-
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
26-
2725
public:
28-
void Init() override { x = 1; }
26+
OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
27+
const VarNameMap& outputs, const AttributeMap& attrs)
28+
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
2929
void InferShape(const Scope& scope) const override {}
3030
void Run(const Scope& scope,
3131
const platform::DeviceContext& dev_ctx) const override {
@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
3838
}
3939

4040
public:
41-
float x = 0;
41+
int x{0};
4242
};
4343

4444
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -104,7 +104,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
104104
static int cpu_kernel_run_num = 0;
105105

106106
class OpWithKernelTest : public OperatorWithKernel {
107-
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
107+
public:
108+
using OperatorWithKernel::OperatorWithKernel;
109+
108110
protected:
109111
void InferShape(const framework::InferShapeContext& ctx) const override {}
110112
};

paddle/operators/add_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ namespace paddle {
1818
namespace operators {
1919

2020
class AddOp : public framework::OperatorWithKernel {
21-
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
2223

2324
protected:
2425
void InferShape(const framework::InferShapeContext &ctx) const override {
@@ -45,7 +46,9 @@ The equation is: Out = X + Y
4546
};
4647

4748
class AddOpGrad : public framework::OperatorWithKernel {
48-
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
49+
public:
50+
using framework::OperatorWithKernel::OperatorWithKernel;
51+
4952
protected:
5053
void InferShape(const framework::InferShapeContext &ctx) const override {}
5154
};

paddle/operators/cross_entropy_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ namespace paddle {
1818
namespace operators {
1919

2020
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
21-
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
2224
protected:
2325
void InferShape(const framework::InferShapeContext &ctx) const override {
2426
auto *X = ctx.Input<Tensor>("X");
@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
3234
};
3335

3436
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
35-
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
36-
framework::OperatorWithKernel)
37+
public:
38+
using framework::OperatorWithKernel::OperatorWithKernel;
39+
3740
protected:
3841
void InferShape(const framework::InferShapeContext &ctx) const override {
3942
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

0 commit comments

Comments
 (0)