@@ -24,7 +24,7 @@ class GradOpDescMakerBase {
2424 explicit GradOpDescMakerBase (const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
2525
2626 virtual ~GradOpDescMakerBase () = default ;
27- virtual std::vector<OpDescBind> operator ()() const = 0;
27+ virtual std::vector<std::unique_ptr< OpDescBind> > operator ()() const = 0;
2828
2929 protected:
3030 static std::vector<std::string> ToGradNames (
@@ -81,34 +81,38 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
8181 public:
8282 using GradOpDescMakerBase::GradOpDescMakerBase;
8383
84- std::vector<OpDescBind> operator ()() const { return {this ->Apply ()}; }
84+ std::vector<std::unique_ptr<OpDescBind>> operator ()() const {
85+ std::vector<std::unique_ptr<OpDescBind>> retv;
86+ retv.emplace_back (this ->Apply ());
87+ return retv;
88+ }
8589
8690 protected:
87- virtual OpDescBind Apply () const = 0;
91+ virtual std::unique_ptr< OpDescBind> Apply () const = 0;
8892};
8993
9094class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
9195 public:
9296 using SingleGradOpDescMaker::SingleGradOpDescMaker;
9397
9498 protected:
95- virtual OpDescBind Apply () const {
96- OpDescBind grad;
97- grad. SetType (this ->GradOpType ());
99+ virtual std::unique_ptr< OpDescBind> Apply () const {
100+ auto * grad = new OpDescBind () ;
101+ grad-> SetType (this ->GradOpType ());
98102
99103 for (auto & input_param : this ->InputNames ()) {
100- grad. SetInput (input_param, this ->Input (input_param));
101- grad. SetOutput (GradVarName (input_param), this ->InputGrad (input_param));
104+ grad-> SetInput (input_param, this ->Input (input_param));
105+ grad-> SetOutput (GradVarName (input_param), this ->InputGrad (input_param));
102106 }
103107
104108 for (auto & output_param : this ->OutputNames ()) {
105- grad. SetInput (output_param, this ->Output (output_param));
106- grad. SetInput (GradVarName (output_param), this ->OutputGrad (output_param));
109+ grad-> SetInput (output_param, this ->Output (output_param));
110+ grad-> SetInput (GradVarName (output_param), this ->OutputGrad (output_param));
107111 }
108112
109- grad. SetAttrMap (this ->Attrs ());
113+ grad-> SetAttrMap (this ->Attrs ());
110114
111- return grad;
115+ return std::unique_ptr<OpDescBind>( grad) ;
112116 }
113117
114118 virtual std::string GradOpType () const {
0 commit comments