Skip to content

Commit b280613

Browse files
committed
Change Interface to unique_ptr
1 parent 495a80a commit b280613

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

doc/design/register_grad_op.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
4242

4343
```cpp
4444
struct OpInfo {
45-
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
45+
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
4646
...
4747
};
4848
```
@@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
5555
class GradOpDescMakerBase {
5656
public:
5757
GradOpDescMakerBase(const OpDescBind& );
58-
virtual std::vector<OpDescBind> operator()()const = 0;
58+
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
5959
};
6060
```
6161

62-
We can convert `GradOpDescMakerBase` to `std::function<std::vector<OpDescBind>(const OpDescBind&)>` by
62+
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
6363

6464
```cpp
6565
using GradOpMaker = ...;

paddle/framework/grad_op_desc_maker.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class GradOpDescMakerBase {
2424
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
2525

2626
virtual ~GradOpDescMakerBase() = default;
27-
virtual std::vector<OpDescBind> operator()() const = 0;
27+
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
2828

2929
protected:
3030
static std::vector<std::string> ToGradNames(
@@ -81,34 +81,38 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
8181
public:
8282
using GradOpDescMakerBase::GradOpDescMakerBase;
8383

84-
std::vector<OpDescBind> operator()() const { return {this->Apply()}; }
84+
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
85+
std::vector<std::unique_ptr<OpDescBind>> retv;
86+
retv.emplace_back(this->Apply());
87+
return retv;
88+
}
8589

8690
protected:
87-
virtual OpDescBind Apply() const = 0;
91+
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
8892
};
8993

9094
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
9195
public:
9296
using SingleGradOpDescMaker::SingleGradOpDescMaker;
9397

9498
protected:
95-
virtual OpDescBind Apply() const {
96-
OpDescBind grad;
97-
grad.SetType(this->GradOpType());
99+
virtual std::unique_ptr<OpDescBind> Apply() const {
100+
auto* grad = new OpDescBind();
101+
grad->SetType(this->GradOpType());
98102

99103
for (auto& input_param : this->InputNames()) {
100-
grad.SetInput(input_param, this->Input(input_param));
101-
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param));
104+
grad->SetInput(input_param, this->Input(input_param));
105+
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
102106
}
103107

104108
for (auto& output_param : this->OutputNames()) {
105-
grad.SetInput(output_param, this->Output(output_param));
106-
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param));
109+
grad->SetInput(output_param, this->Output(output_param));
110+
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
107111
}
108112

109-
grad.SetAttrMap(this->Attrs());
113+
grad->SetAttrMap(this->Attrs());
110114

111-
return grad;
115+
return std::unique_ptr<OpDescBind>(grad);
112116
}
113117

114118
virtual std::string GradOpType() const {

paddle/framework/op_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace framework {
2828
struct OpInfo {
2929
OpCreator creator_;
3030
std::string grad_op_type_;
31-
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
31+
GradOpMakerFN grad_op_maker_;
3232
OpProto* proto_{nullptr};
3333
OpAttrChecker* checker_{nullptr};
3434

paddle/framework/type_defs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
namespace paddle {
2121
namespace framework {
2222
class OperatorBase;
23+
class OpDescBind;
2324
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
2425

2526
// The order should be as same as framework.proto
@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
3435
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
3536
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
3637

38+
using GradOpMakerFN =
39+
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
40+
3741
} // namespace framework
3842
} // namespace paddle

0 commit comments

Comments
 (0)