Feature/backward #3068
Changes from 1 commit
First changed file, the backward construction logic:

```diff
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
   return true;
 }
 
 static std::vector<size_t> InSetIdx(
     const std::vector<std::string>& names, const std::string& suffix,
     const std::unordered_set<std::string>& set) {
   std::vector<size_t> ret_val;
   ret_val.reserve(names.size());
   for (size_t i = 0; i < names.size(); ++i) {
     if (set.find(names[i] + suffix) != set.end()) {
       ret_val.push_back(i);
     }
   }
   return ret_val;
 }
 
-static std::shared_ptr<OperatorBase> EmptyOp() {
+static std::shared_ptr<OperatorBase> NOP() {
   auto net_op = std::make_shared<NetOp>();
-  net_op->type_ = "@EMPTY_OP@";
+  net_op->type_ = "@NOP@";
   net_op->CompleteAddOp();
   return net_op;
 }
 
-/**
- * @brief Backward an operator, implementation
- * @param forwardOp the forward operator
- * @param no_grad_names variable names not calculate for gradient. Like X@GRAD
- * is not needed.
- * @param uniq_id a unique index used inside BackwardImpl, it will be shared
- * through recursive invoke.
- * @return The backward operator. For simple situation, it is a simple operator.
- * For complex situation, it is a NetOp.
- *
- * See Backward.h for details
- */
-static std::shared_ptr<OperatorBase> BackwardImpl(
+// Get backward operator from a forward operator, recursively implementation.
+//
+// no_grad_names the gradient variable names without gradient calculating.
+//
+// uniq_id is a unique index used inside recursively calling BackwardRecursive.
+// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
+// recursive calling.
+//
+// returns The backward operator. For simple situation, it is a simple
+// operator. For complex situation, it is a NetOp.
+//
+// See Backward.h for details
+static std::shared_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
+std::shared_ptr<OperatorBase> BackwardRecursive(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
-  /**
-   * If all input gradients of forwarding operator do not need to calculate,
-   * just return an EmptyOp. Not return null ptr because EmptyOp does not take
-   * too much time for calculation, but it is useful for simplifying logic.
-   */
+  // If all input gradients of forwarding operator do not need to calculate,
+  // just return an NOP. Not return null ptr because NOP does not take
+  // too much time for calculation, but it is useful for simplifying logic.
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
-    return EmptyOp();
+    return NOP();
   }
 
-  /**
-   * All output gradients of forwarding operator do not need to calculate. Then
-   * all input gradients cannot be computed at all, and we put them into
-   * `no_grad_names` set. Return an EmptyOp.
-   */
+  // All output gradients of forwarding operator do not need to calculate. Then
+  // all input gradients cannot be computed at all, and we put them into
+  // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
-      /// Mark all input is not need
+      // Mark all input is not need
       no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
     }
-    return EmptyOp();
+    return NOP();
   }
 
-  //! Returned gradient network
+  // Returned gradient network
   auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
```
Contributor: Why can't the logic for getting a NetOp's backward op be written as a method on NetOp itself?

Contributor: @qingqing01 My understanding is that this is the construction logic for Backward, which is not at the same level as a NetOp's own backward pass. A gradient operator is a pluggable unit, whereas Backward is part of the system core.

Collaborator (Author): Perhaps the backward of complex ops could be registered somewhere else, but that registration mechanism would feel very inconsistent. The simplest case is clearly to write it this way; until we come up with a better approach, let's do it like this for now.
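Side note on the `NOP()` helper introduced at the top of this hunk: returning a cheap do-nothing op instead of a null pointer is essentially the null-object pattern, which is what the updated comment argues for ("NOP does not take too much time for calculation, but it is useful for simplifying logic"). A minimal standalone sketch of that idea; the `Op`, `NopOp`, and `PrintOp` types below are made up for illustration and are not the framework's classes:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-ins for OperatorBase and the NOP net; illustration only.
struct Op {
  virtual ~Op() = default;
  virtual void Run() const = 0;
};

struct NopOp : Op {
  void Run() const override {}  // does nothing and costs almost nothing
};

struct PrintOp : Op {
  void Run() const override { std::cout << "real backward op\n"; }
};

int main() {
  // Every slot holds a valid Op, so callers never need a nullptr check.
  std::vector<std::shared_ptr<Op>> net{std::make_shared<PrintOp>(),
                                       std::make_shared<NopOp>(),
                                       std::make_shared<PrintOp>()};
  for (const auto& op : net) {
    op->Run();  // the NopOp silently falls through
  }
  return 0;
}
```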
```diff
-    /// Because forwardOp is a net op, it can static_cast.
+    // Because forwardOp is a net op, it can static_cast.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    //! Map from output gradient variable name to operator's indices in backward
-    //! net. That operator generates that variable.
+    // Map from output gradient variable name to operator's indices in backward
+    // net. That operator generates that variable.
     std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
```
Member: what is `dup`?

Collaborator (Author): duplicated
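To expand on the thread above, here is a minimal standalone sketch of the bookkeeping `dup_output_ops` does ("dup" meaning duplicated): remember which backward ops write each output-gradient name, so names produced by more than one op can be merged (for example summed) afterwards. The fake op outputs and variable names are made up; only the map building mirrors the diff:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Outputs of the backward ops, listed in the order the ops were added to
  // the backward net (hypothetical example data).
  std::vector<std::vector<std::string>> bwd_outputs = {
      {"W@GRAD"}, {"X@GRAD"}, {"X@GRAD"}};  // X@GRAD is produced twice

  std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
  for (size_t local_op_id = 0; local_op_id < bwd_outputs.size();
       ++local_op_id) {
    for (const auto& out : bwd_outputs[local_op_id]) {
      dup_output_ops[out].emplace_back(local_op_id);
    }
  }

  for (const auto& kv : dup_output_ops) {
    if (kv.second.size() > 1) {
      std::cout << kv.first << " is written by " << kv.second.size()
                << " backward ops and needs to be merged afterwards\n";
    }
  }
  return 0;
}
```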
```diff
     size_t local_op_id = 0;
-    /// reversely travel forwardNet
+    // reversely travel forwardNet
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
       auto fwd = *it;
-      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
+      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
       for (auto& out : bwd->outputs_) {
         dup_output_ops[out].emplace_back(local_op_id);
       }
     }
-    /// Get unique ID for this method.
+    // Get unique ID for this method.
     auto uid = uniq_id++;
     // TODO(dzh): more comment
     using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
```
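The rewritten comment says `uniq_id` is sampled with `uid = uniq_id++;` and passed through the recursive calls. A standalone sketch of why a by-reference counter keeps the ids unique across the whole recursion; the `Node` type and the tree are illustrative, not part of the PR:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical tree node standing in for a (possibly nested) NetOp.
struct Node {
  std::vector<Node> children;
};

void Visit(const Node& n, std::size_t& uniq_id) {
  std::size_t uid = uniq_id++;  // same pattern as `auto uid = uniq_id++;` above
  std::printf("node got uid %zu\n", uid);
  for (const auto& c : n.children) {
    Visit(c, uniq_id);  // the counter is shared, so ids never repeat
  }
}

int main() {
  Node leaf;
  Node child;
  child.children.push_back(leaf);
  Node root;
  root.children.push_back(Node{});
  root.children.push_back(child);

  std::size_t uniq_id = 0;
  Visit(root, uniq_id);  // prints uids 0, 1, 2, 3 with no duplicates
  return 0;
}
```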
```diff
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     }
 
   } else {
-    //! TODO(fjy)
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
     for (std::string& grad_input : grad_op->inputs_) {
       if (no_grad_names.count(grad_input)) {
         std::string prefix = grad_input.substr(
             0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
         grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
 
+        // If part of input gradient of that operator is not calculated, fill
+        // zero variables to that input gradient.
+        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
+                                        {grad_input}, {}));
```
Contributor: need more comments for
```diff
       }
```
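A standalone sketch of the renaming done just above when a required input gradient appears in `no_grad_names`: strip the gradient suffix, attach the zero-variable suffix, and let a `fill_zeros_like` op produce the all-zero placeholder. The suffix literals are stand-ins; `"@GRAD"` matches the `X@GRAD` example in the comments, while the zero suffix value is an assumption since `ZERO_VAR_SUFFIX()` is not shown in this diff:

```cpp
#include <iostream>
#include <string>

int main() {
  const std::string kGradSuffix = "@GRAD";  // stand-in for GRAD_VAR_SUFFIX()
  const std::string kZeroSuffix = "@ZERO";  // assumed value for ZERO_VAR_SUFFIX()

  // The gradient name the generated grad op expects but that cannot be computed.
  std::string grad_input = "X@GRAD";

  std::string prefix =
      grad_input.substr(0, grad_input.size() - kGradSuffix.size());
  grad_input = prefix + kZeroSuffix;  // the grad op now reads X@ZERO instead

  // The real code then adds:
  //   OpRegistry::CreateOp("fill_zeros_like", {prefix}, {grad_input}, {})
  std::cout << "fill_zeros_like: " << prefix << " -> " << grad_input << "\n";
  return 0;
}
```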
```diff
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   return net;
 }
 
-//! See header for comments
-extern std::shared_ptr<OperatorBase> Backward(
+// See header for comments
+std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;
 
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
   }
   size_t uid = 0;
-  return BackwardImpl(forwardOp, no_grad_names, uid);
+  return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
 }  // namespace framework
 }  // namespace paddle
```
Second changed file, the FillZerosLike kernel:

```diff
@@ -23,8 +23,8 @@ namespace operators {
 template <typename Place, typename T>
 class FillZerosLikeKernel : public framework::OpKernel {
  public:
```
Member: Please add a Python unit test for this operator.

Collaborator (Author): We will add the unit test in another PR. @Canpio @dzhwinter

Contributor: @reyoung got it
```diff
-  void Compute(const framework::KernelContext& context) const override {
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const framework::ExecutionContext& context) const override {
```
Member: framework::xxx ===> xxx

Collaborator (Author): Changed in next PR.
```diff
+    auto* output = context.Output<framework::Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
     framework::EigenVector<T>::Flatten(*output).setZero();
   }
```
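The kernel body boils down to flattening the output tensor and zeroing it with Eigen. A standalone sketch of the same pattern on a plain buffer, assuming Eigen's unsupported Tensor module is available; no Paddle types are involved:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>
#include <vector>

int main() {
  // Pretend this buffer backs the operator's output tensor.
  std::vector<float> data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};

  // EigenVector<T>::Flatten(*output) does essentially this: view the memory
  // as a rank-1 Eigen::TensorMap, then setZero() writes zeros element-wise.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> flat(
      data.data(), static_cast<Eigen::DenseIndex>(data.size()));
  flat.setZero();

  for (float v : data) {
    std::cout << v << " ";  // prints: 0 0 0 0 0 0
  }
  std::cout << "\n";
  return 0;
}
```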
A further review exchange on the no-gradient pruning logic:

Review comment: I can understand that when none of the inputs needs its gradient computed, all of the outputs can be marked as not computing gradients. But I haven't quite figured out in what situation the inputs get marked as not computing gradients based on the outputs not needing gradients.

Reply: During the reverse traversal, if none of the operators that consume this op's output variables computes a gradient, the op itself has no way to compute its own gradients either.
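A tiny standalone simulation of the pruning rule described in this reply. The two-op chain and the `@GRAD` suffix handling are illustrative; only the marking step mirrors the `AllInSet(forwardOp.outputs_, ...)` branch in the diff:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an operator's input/output variable names.
struct FakeOp {
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};

// Mirrors AllInSet(): true when every name, with the gradient suffix appended,
// is already in the set.
static bool AllGradInSet(const std::vector<std::string>& names,
                         const std::unordered_set<std::string>& set) {
  for (const auto& n : names) {
    if (set.find(n + "@GRAD") == set.end()) return false;
  }
  return true;
}

int main() {
  // Forward chain: fc1 (X -> H), fc2 (H -> Y). Suppose Y@GRAD is not needed,
  // e.g. because every consumer of Y was already pruned.
  FakeOp fc2{{"H"}, {"Y"}};
  std::unordered_set<std::string> no_grad_names{"Y@GRAD"};

  // The reverse traversal reaches fc2 first. All of fc2's output gradients
  // are unavailable, so fc2 cannot produce H@GRAD either: mark it and move on.
  if (AllGradInSet(fc2.outputs_, no_grad_names)) {
    for (const auto& in : fc2.inputs_) {
      no_grad_names.insert(in + "@GRAD");
    }
  }

  std::cout << "H@GRAD pruned: " << no_grad_names.count("H@GRAD") << "\n";  // 1
  return 0;
}
```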