@@ -117,13 +117,19 @@ class OpProtoAndCheckerMaker {
117117};
118118
119119class OpRegistry {
120- using OpCreator = std::function<OperatorBase*()>;
121- using VarNameMap = std::map<std::string, std::vector<std::string>>;
120+ using VarNameMap = OperatorBase::VarNameMap;
121+ using OpCreator = std::function<OperatorBase*(
122+ const std::string& /* type*/ , const VarNameMap& /* inputs*/ ,
123+ const VarNameMap& /* outputs*/ , const AttributeMap& /* attrs*/ )>;
122124
123125 public:
124126 template <typename OpType, typename ProtoMakerType>
125127 static void RegisterOp (const std::string& op_type) {
126- op_creators ()[op_type] = [] { return new OpType; };
128+ op_creators ()[op_type] = [](
129+ const std::string& type, const VarNameMap& inputs,
130+ const VarNameMap& outputs, const AttributeMap& attrs) {
131+ return new OpType (type, inputs, outputs, attrs);
132+ };
127133 OpAttrChecker& op_checker = op_checkers ()[op_type];
128134 OpProto& op_proto = OpProtos ()[op_type];
129135 auto maker = ProtoMakerType (&op_proto, &op_checker);
@@ -138,29 +144,25 @@ class OpRegistry {
138144 template <typename GradOpType>
139145 static void RegisterGradOp (const std::string& op_type,
140146 const std::string& grad_op_type) {
141- op_creators ()[grad_op_type] = [] { return new GradOpType; };
147+ op_creators ()[grad_op_type] = [](
148+ const std::string& type, const VarNameMap& inputs,
149+ const VarNameMap& outputs, const AttributeMap& attrs) {
150+ return new GradOpType (type, inputs, outputs, attrs);
151+ };
142152 grad_ops ()[op_type] = grad_op_type;
143153 }
144154
145155 static std::shared_ptr<OperatorBase> CreateOp (const std::string& type,
146156 const VarNameMap& inputs,
147157 const VarNameMap& outputs,
148- const AttributeMap& attrs) {
158+ AttributeMap attrs) {
149159 auto op_create_it = op_creators ().find (type);
150160 PADDLE_ENFORCE (op_create_it != op_creators ().end (),
151161 " Operator %s cannot be found." , type);
162+ op_checkers ().at (type).Check (attrs);
152163
153- auto op = op_create_it->second ();
154- op->type_ = type;
155- op->inputs_ = inputs;
156- op->outputs_ = outputs;
157-
158- op->attrs_ = attrs;
159- op_checkers ().at (type).Check (op->attrs_ );
160-
161- GenerateTempVariableName (op);
164+ auto op = op_create_it->second (type, inputs, outputs, attrs);
162165
163- op->Init ();
164166 return std::shared_ptr<OperatorBase>(op);
165167 }
166168
@@ -195,7 +197,6 @@ class OpRegistry {
195197 PADDLE_ENFORCE (!op.IsNetOp (),
196198 " Use framework::Backward to get backward ops" );
197199 std::shared_ptr<OperatorBase> grad_op (BuildGradOp (&op));
198- grad_op->Init ();
199200 return grad_op;
200201 }
201202
@@ -214,19 +215,6 @@ class OpRegistry {
214215 static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
215216 return op_checkers_;
216217 }
217-
218- static void GenerateTempVariableName (OperatorBase* op) {
219- static std::atomic<size_t > gUniqId (0UL );
220- for (auto & output : op->outputs_ ) {
221- for (auto & output_name : output.second ) {
222- if (output_name == kTempVarName ) {
223- output_name += op->type_ ;
224- output_name += " @" ;
225- output_name += std::to_string (gUniqId .fetch_add (1 ));
226- }
227- }
228- }
229- }
230218};
231219
232220class Registrar {
0 commit comments