Skip to content

Commit b884bc3

Browse files
authored
Merge pull request #4551 from reyoung/feature/grad_reg_mechanism_cont
Add helper function in GradOpDescMakerBase. Make it easier to use.
2 parents 8bf209f + 703321e commit b884bc3

File tree

7 files changed

+176
-36
lines changed

7 files changed

+176
-36
lines changed

doc/design/register_grad_op.md

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun
3333
3434
```cpp
3535
// (OpDesc) --> vector<OpDesc>
36-
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
36+
std::function<std::vector<OpDescBind>(const OpDescBind&)>;
3737
```
3838

39-
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions.
39+
The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for protobuf message `OpDesc` to manipulate `OpDesc` fast.
4040

4141
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
4242

4343
```cpp
4444
struct OpInfo {
45-
GradOpDescMaker grad_op_maker_;
45+
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
4646
...
4747
};
4848
```
4949
5050
The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators.
5151
52+
We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is
53+
54+
```cpp
55+
class GradOpDescMakerBase {
56+
public:
57+
GradOpDescMakerBase(const OpDescBind& );
58+
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
59+
};
60+
```
61+
62+
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
63+
64+
```cpp
65+
using GradOpMaker = ...;
66+
std::function<std::vector<OpDescBind>(const OpDescBind&)> func;
67+
func = [] (const OpDescBind& fwd_op) {
68+
GradOpMaker maker(fwd_op);
69+
return maker();
70+
};
71+
```
72+
73+
We can write many helper functions since the `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forwarding operator.
74+
5275
We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`.
5376

5477
The user interface should be

paddle/framework/details/op_registry.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "paddle/framework/grad_op_desc_maker.h"
1718
#include "paddle/framework/op_info.h"
1819
#include "paddle/framework/op_proto_maker.h"
1920
#include "paddle/framework/operator.h"
@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
9697
template <typename T>
9798
struct OpInfoFiller<T, kGradOpDescMaker> {
9899
void operator()(const char* op_type, OpInfo* info) const {
99-
info->grad_op_maker_ = new T();
100+
info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
101+
T maker(fwd_op);
102+
return maker();
103+
};
100104
}
101105
};
102106
} // namespace details
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/op_desc.h"
17+
#include "paddle/framework/operator.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
22+
class GradOpDescMakerBase {
23+
public:
24+
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
25+
26+
virtual ~GradOpDescMakerBase() = default;
27+
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
28+
29+
protected:
30+
static std::vector<std::string> ToGradNames(
31+
const std::vector<std::string>& var_names) {
32+
std::vector<std::string> ret_val;
33+
ret_val.reserve(var_names.size());
34+
std::transform(var_names.begin(), var_names.end(),
35+
std::back_inserter(ret_val), GradVarName);
36+
return ret_val;
37+
}
38+
39+
std::vector<std::string> InputGrad(const std::string& name) const {
40+
return ToGradNames(fwd_op_.Input(name));
41+
}
42+
43+
std::vector<std::string> OutputGrad(const std::string& name) const {
44+
return ToGradNames(fwd_op_.Output(name));
45+
}
46+
47+
std::vector<std::string> InputNames() const {
48+
return this->fwd_op_.InputNames();
49+
}
50+
51+
std::vector<std::string> OutputNames() const {
52+
return this->fwd_op_.OutputNames();
53+
}
54+
55+
std::vector<std::string> Input(const std::string& name) const {
56+
return fwd_op_.Input(name);
57+
}
58+
59+
std::vector<std::string> Output(const std::string& name) const {
60+
return fwd_op_.Output(name);
61+
}
62+
63+
const std::unordered_map<std::string, Attribute>& Attrs() const {
64+
return fwd_op_.GetAttrMap();
65+
}
66+
67+
const Attribute& GetAttr(const std::string& name) const {
68+
auto& map = fwd_op_.GetAttrMap();
69+
auto it = map.find(name);
70+
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
71+
return it->second;
72+
}
73+
74+
std::string ForwardOpType() const { return this->fwd_op_.Type(); }
75+
76+
private:
77+
const OpDescBind& fwd_op_;
78+
};
79+
80+
class SingleGradOpDescMaker : public GradOpDescMakerBase {
81+
public:
82+
using GradOpDescMakerBase::GradOpDescMakerBase;
83+
84+
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
85+
std::vector<std::unique_ptr<OpDescBind>> retv;
86+
retv.emplace_back(this->Apply());
87+
return retv;
88+
}
89+
90+
protected:
91+
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
92+
};
93+
94+
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
95+
public:
96+
using SingleGradOpDescMaker::SingleGradOpDescMaker;
97+
98+
protected:
99+
virtual std::unique_ptr<OpDescBind> Apply() const {
100+
auto* grad = new OpDescBind();
101+
grad->SetType(this->GradOpType());
102+
103+
for (auto& input_param : this->InputNames()) {
104+
grad->SetInput(input_param, this->Input(input_param));
105+
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
106+
}
107+
108+
for (auto& output_param : this->OutputNames()) {
109+
grad->SetInput(output_param, this->Output(output_param));
110+
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
111+
}
112+
113+
grad->SetAttrMap(this->Attrs());
114+
115+
return std::unique_ptr<OpDescBind>(grad);
116+
}
117+
118+
virtual std::string GradOpType() const {
119+
return this->ForwardOpType() + "_grad";
120+
}
121+
};
122+
123+
} // namespace framework
124+
} // namespace paddle

paddle/framework/op_desc.cc

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
3131
return it->second;
3232
}
3333

34-
std::vector<std::string> OpDescBind::InputNames() const {
35-
std::vector<std::string> retv;
36-
retv.reserve(this->inputs_.size());
37-
for (auto &ipt : this->inputs_) {
38-
retv.push_back(ipt.first);
39-
}
40-
return retv;
41-
}
42-
4334
void OpDescBind::SetInput(const std::string &param_name,
4435
const std::vector<std::string> &args) {
4536
need_update_ = true;
@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
5445
return it->second;
5546
}
5647

57-
std::vector<std::string> OpDescBind::OutputNames() const {
58-
std::vector<std::string> retv;
59-
retv.reserve(this->outputs_.size());
60-
for (auto &ipt : this->outputs_) {
61-
retv.push_back(ipt.first);
62-
}
63-
return retv;
64-
}
65-
6648
void OpDescBind::SetOutput(const std::string &param_name,
6749
const std::vector<std::string> &args) {
6850
need_update_ = true;

paddle/framework/op_desc.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,11 @@ class OpDescBind {
3535

3636
const std::vector<std::string> &Input(const std::string &name) const;
3737

38-
std::vector<std::string> InputNames() const;
39-
4038
void SetInput(const std::string &param_name,
4139
const std::vector<std::string> &args);
4240

4341
const std::vector<std::string> &Output(const std::string &name) const;
4442

45-
std::vector<std::string> OutputNames() const;
46-
4743
void SetOutput(const std::string &param_name,
4844
const std::vector<std::string> &args);
4945

@@ -61,17 +57,30 @@ class OpDescBind {
6157

6258
void SetBlockAttr(const std::string &name, BlockDescBind &block);
6359

64-
// Only be used in C++
65-
void SetAttrMap(const AttributeMap &attr_map);
66-
6760
Attribute GetAttr(const std::string &name) const;
6861

6962
int GetBlockAttr(const std::string &name) const;
7063

7164
// Only be used in C++
7265
const AttributeMap &GetAttrMap() const;
7366

67+
// Only be used in C++
68+
void SetAttrMap(const AttributeMap &attr_map);
69+
70+
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
71+
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
72+
7473
private:
74+
template <typename MapType>
75+
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
76+
std::vector<typename MapType::key_type> ret_val;
77+
ret_val.reserve(map.size());
78+
std::transform(
79+
map.begin(), map.end(), std::back_inserter(ret_val),
80+
[](const typename MapType::value_type &pair) { return pair.first; });
81+
return ret_val;
82+
}
83+
7584
void Sync();
7685

7786
OpDesc op_desc_;

paddle/framework/op_info.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,10 @@
2525
namespace paddle {
2626
namespace framework {
2727

28-
class GradOpDescMakerBase {
29-
public:
30-
virtual ~GradOpDescMakerBase() = default;
31-
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
32-
};
33-
3428
struct OpInfo {
3529
OpCreator creator_;
3630
std::string grad_op_type_;
37-
GradOpDescMakerBase* grad_op_maker_{nullptr};
31+
GradOpMakerFN grad_op_maker_;
3832
OpProto* proto_{nullptr};
3933
OpAttrChecker* checker_{nullptr};
4034

paddle/framework/type_defs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
namespace paddle {
2121
namespace framework {
2222
class OperatorBase;
23+
class OpDescBind;
2324
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
2425

2526
// The order should be as same as framework.proto
@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
3435
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
3536
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
3637

38+
using GradOpMakerFN =
39+
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
40+
3741
} // namespace framework
3842
} // namespace paddle

0 commit comments

Comments
 (0)