Feature/backward #3068
`paddle/framework/backward.cc` (new file, +190 lines):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

/// Returns true iff `name + suffix` is in `set` for every name in `names`.
static bool AllInSet(const std::vector<std::string>& names,
                     const std::string& suffix,
                     const std::unordered_set<std::string>& set) {
  for (auto& name : names) {
    if (set.find(name + suffix) == set.end()) {
      return false;
    }
  }
  return true;
}

/// Returns the indices of every name whose `name + suffix` is in `set`.
static std::vector<size_t> InSetIdx(
    const std::vector<std::string>& names, const std::string& suffix,
    const std::unordered_set<std::string>& set) {
  std::vector<size_t> ret_val;
  ret_val.reserve(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    if (set.find(names[i] + suffix) != set.end()) {
      ret_val.push_back(i);
    }
  }
  return ret_val;
}

/// A placeholder operator that performs no computation.
static std::shared_ptr<OperatorBase> EmptyOp() {
  auto net_op = std::make_shared<NetOp>();
  net_op->type_ = "@EMPTY_OP@";
  net_op->CompleteAddOp();
  return net_op;
}

/**
 * @brief Backward an operator, implementation
 * @param forwardOp the forward operator
 * @param no_grad_names variable names for which no gradient is calculated,
 *        e.g. X@GRAD is not needed.
 * @param uniq_id a unique index used inside BackwardImpl; it is shared
 *        across recursive invocations.
 * @return The backward operator. In simple situations it is a plain
 *         operator; in complex situations it is a NetOp.
 *
 * See Backward.h for details.
 */
static std::shared_ptr<OperatorBase> BackwardImpl(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
  /**
   * If none of the forward operator's input gradients need to be calculated,
   * just return an EmptyOp. We do not return a null pointer because running
   * an EmptyOp costs almost nothing, and it simplifies the logic.
   */
  if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    return EmptyOp();
  }

  /**
   * If none of the forward operator's output gradients need to be
   * calculated, then none of its input gradients can be computed at all;
   * we put them into the `no_grad_names` set and return an EmptyOp.
   */
  if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    for (auto& name : forwardOp.inputs_) {
      /// Mark all inputs as not needed.
      no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
    }
    return EmptyOp();
  }
```
Contributor: I can see that when none of an op's inputs need their gradients computed, all of its outputs can be marked as not needing gradients — but here it is the inputs that get marked when none of the outputs need gradients.

Author: During the reverse traversal, if none of the operators consuming an op's output variables compute gradients, that op has no way to compute its own gradients either.
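To make the pruning rule concrete, here is a minimal standalone sketch (plain C++, not framework code; the `@GRAD` suffix mirrors `GRAD_VAR_SUFFIX()`, and the op with inputs `{W, X}` and output `{Out}` is hypothetical):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Standalone illustration of the pruning rule discussed above: if every
// output gradient of an op is unneeded, none of its input gradients can be
// computed, so they are marked unneeded as well.
int main() {
  const std::string kGradSuffix = "@GRAD";  // mirrors GRAD_VAR_SUFFIX()
  std::vector<std::string> inputs{"W", "X"};  // hypothetical op: (W, X) -> Out
  std::vector<std::string> outputs{"Out"};
  std::unordered_set<std::string> no_grad_names{"Out@GRAD"};

  // Same check as AllInSet(outputs, GRAD_VAR_SUFFIX(), no_grad_names).
  auto all_in_set = [&](const std::vector<std::string>& names) {
    for (auto& n : names) {
      if (no_grad_names.find(n + kGradSuffix) == no_grad_names.end()) {
        return false;
      }
    }
    return true;
  };

  if (all_in_set(outputs)) {
    for (auto& n : inputs) {
      no_grad_names.insert(n + kGradSuffix);  // propagate to the inputs
    }
  }
  std::cout << no_grad_names.count("W@GRAD") << "\n";  // prints 1
  return 0;
}
```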
```cpp
  //! Returned gradient network
  auto net = std::make_shared<NetOp>();

  if (forwardOp.IsNetOp()) {
```
Contributor: Why can't building the backward op of a NetOp be written as a method on NetOp itself?

Contributor: @qingqing01 My understanding is that this is the generation logic of Backward, which is not on the same level as NetOp's own backward pass. A gradient operator is a pluggable unit, whereas Backward is part of the system core.

Author: Perhaps the backward of complex ops could be registered somewhere else, but such a registration mechanism feels very inconsistent. The simplest case is certainly written this way; until we come up with a better approach, let's do it like this for now.
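For illustration, a minimal standalone sketch of the dispatch pattern being discussed (plain C++ with hypothetical `Op`/`Net` stand-ins, not the framework's classes): backward construction recurses into composite ops in reverse order and bottoms out at plain ops, where the real code asks the registry for a gradient op.

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-ins for OperatorBase and NetOp (not the framework types).
struct Op {
  virtual ~Op() = default;
  virtual bool IsNet() const { return false; }
};
struct Net : Op {
  bool IsNet() const override { return true; }
  std::vector<std::shared_ptr<Op>> ops;
};

// Mirrors BackwardImpl's shape: recurse into composite ops in reverse order,
// bottom out at plain ops.
static void BuildBackward(const Op& fwd, int depth) {
  if (fwd.IsNet()) {
    auto& net = static_cast<const Net&>(fwd);
    for (auto it = net.ops.rbegin(); it != net.ops.rend(); ++it) {
      BuildBackward(**it, depth + 1);
    }
  } else {
    std::cout << "grad op at depth " << depth << "\n";
  }
}

int main() {
  Net outer;
  outer.ops.push_back(std::make_shared<Op>());  // a plain op
  auto inner = std::make_shared<Net>();         // a nested net
  inner->ops.push_back(std::make_shared<Op>());
  outer.ops.push_back(inner);
  BuildBackward(outer, 0);
  // Reverse order: the nested net's op is visited first (depth 2),
  // then the outer plain op (depth 1).
}
```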
```cpp
    /// Because forwardOp is a net op, the static_cast is safe.
    auto& forwardNet = static_cast<const NetOp&>(forwardOp);

    //! Map from an output gradient variable name to the indices (in the
    //! backward net) of the operators that generate that variable.
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
```
Member: What is `dup`?

Author: duplicated
```cpp
    size_t local_op_id = 0;
    /// traverse forwardNet in reverse order
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
         ++it, ++local_op_id) {
      auto fwd = *it;
      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
      net->AddOp(bwd);
      for (auto& out : bwd->outputs_) {
        dup_output_ops[out].emplace_back(local_op_id);
      }
    }
    /// Get a unique ID for this method.
    auto uid = uniq_id++;
    // TODO(dzh): more comment
    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
    for (auto& dup_output_op : dup_output_ops) {
      const std::string& name = dup_output_op.first;
      auto& dup_op = dup_output_op.second;
      if (dup_op.size() == 1) continue;
      std::vector<std::string> dup_outputs;

      for (size_t i = 0; i < dup_op.size(); ++i) {
        auto op_offset = dup_op[i];
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
      }
      insert_position.push_back(
          {dup_op.back(),
           OpRegistry::CreateOp(
               "add", {dup_outputs}, {name},
               {{"input_format",
                 std::vector<int>{0, (int)dup_outputs.size()}}})});
    }
```

Member: This `add` op hasn't been implemented yet, right?

Author: Right. Implementing that op does not affect the implementation or the unit tests of the Backward algorithm.
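A standalone sketch of the `@RENAME` de-duplication scheme above (plain C++, not framework code): when several backward ops write the same gradient variable, each write gets a unique name, and an `add` op later sums the renamed copies back into the original name.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Two backward ops both produced X@GRAD; rename each write uniquely so an
// `add` op can combine them.
int main() {
  const std::string name = "X@GRAD";
  size_t uid = 0;  // per-BackwardImpl unique id (uniq_id++ in the real code)
  std::vector<std::string> dup_outputs;
  for (size_t i = 0; i < 2; ++i) {
    dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                          std::to_string(i));
  }
  for (auto& s : dup_outputs) {
    std::cout << s << "\n";  // X@GRAD@RENAME@0@0, then X@GRAD@RENAME@0@1
  }
  // Conceptually: add(X@GRAD@RENAME@0@0, X@GRAD@RENAME@0@1) -> X@GRAD
}
```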
```cpp
    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });

    for (auto& pos : insert_position) {
      net->InsertOp(pos.first + 1, pos.second);
    }
```
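The descending sort matters: the `add` ops are spliced in from the back of the net toward the front, so earlier insertion indices are not shifted by later inserts. A standalone sketch of that invariant (plain C++, with a `std::vector<std::string>` standing in for the op list):

```cpp
#include <iostream>
#include <list>
#include <string>
#include <utility>
#include <vector>

// Pending inserts at positions 0 and 2, applied in descending order so the
// smaller index stays valid after the larger one is handled.
int main() {
  using Pos = std::pair<size_t, std::string>;
  std::vector<std::string> ops{"op0", "op1", "op2", "op3"};
  std::list<Pos> insert_position{{0, "add_a"}, {2, "add_b"}};

  // Same comparator as the real code: larger indices first.
  insert_position.sort(
      [](const Pos& l, const Pos& r) { return l.first > r.first; });

  for (auto& pos : insert_position) {
    // Insert right after the op at index pos.first, like net->InsertOp.
    ops.insert(ops.begin() + pos.first + 1, pos.second);
  }
  for (auto& op : ops) std::cout << op << " ";
  std::cout << "\n";  // op0 add_a op1 op2 add_b op3
}
```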
```cpp
  } else {
    //! TODO(fjy)
    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
    for (std::string& grad_input : grad_op->inputs_) {
      if (no_grad_names.count(grad_input)) {
        std::string prefix = grad_input.substr(
            0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
        grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
                                        {grad_input}, {}));
      }
    }
```

Contributor: Need more comments for this.
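A standalone sketch of the substitution above (plain C++, not framework code; `@GRAD` mirrors `GRAD_VAR_SUFFIX()`, and `@ZERO` is only an assumed value for `ZERO_VAR_SUFFIX()`): when a gradient input of the grad op is known to be unneeded, it is redirected to a zero-filled variable so the grad op can still run.

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// The real suffixes come from GRAD_VAR_SUFFIX() / ZERO_VAR_SUFFIX();
// "@GRAD" / "@ZERO" are assumed here for illustration.
int main() {
  const std::string kGrad = "@GRAD";
  const std::string kZero = "@ZERO";
  std::unordered_set<std::string> no_grad_names{"X@GRAD"};

  std::string grad_input = "X@GRAD";
  if (no_grad_names.count(grad_input)) {
    // Strip the gradient suffix to recover the forward variable name.
    std::string prefix =
        grad_input.substr(0, grad_input.size() - kGrad.size());  // "X"
    grad_input = prefix + kZero;  // the grad op now reads "X@ZERO"
    // In the real code: net->AddOp(fill_zeros_like(X) -> X@ZERO).
  }
  std::cout << grad_input << "\n";  // prints X@ZERO
}
```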
```cpp
    for (std::string& grad_output : grad_op->outputs_) {
      if (no_grad_names.count(grad_output)) {
        grad_output = OperatorBase::EMPTY_VAR_NAME();
      }
    }

    if (net->ops_.empty()) {  // Currently no aux op has been added to the net
      return grad_op;
    }
    net->AddOp(grad_op);
  }
  net->type_ = "@GENERATED_BACKWARD@";
  net->CompleteAddOp();
  return net;
}

//! See header for comments
extern std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
  no_grad_names.reserve(no_grad_vars.size());

  for (auto& name : no_grad_vars) {
    no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
  }
  size_t uid = 0;
  return BackwardImpl(forwardOp, no_grad_names, uid);
}
}  // namespace framework
}  // namespace paddle
```
`paddle/framework/backward.h` (new file, +31 lines):
```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <unordered_set>
#include "operator.h"
namespace paddle {
namespace framework {

/**
 * @brief Create the backward operator of `forwardOp`.
 * @param forwardOp the forward operator
 * @param no_grad_vars names of forward inputs whose gradients are ignored
 * @return the backward operator; a plain operator in simple situations,
 *         a NetOp in complex ones
 */
extern std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars);
}  // namespace framework
}  // namespace paddle
```
Member: Why do we need to use `static`?

Author: Because we do not export them as global symbols.
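For context, a minimal single-file illustration of what `static` buys here (plain C++, not framework code): internal linkage keeps the helpers out of the object file's exported symbol table, so another translation unit could define same-named functions without a one-definition-rule clash.

```cpp
// internal_linkage_demo.cc — compiles on its own.
// `static` at namespace scope gives a function internal linkage: it is not
// exported from this translation unit, so helpers such as backward.cc's
// AllInSet, InSetIdx, EmptyOp, and BackwardImpl stay invisible to the
// linker, and other .cc files may reuse the same names without conflict.
static int LocalHelper() { return 42; }  // analogous to the helpers above

int main() { return LocalHelper() == 42 ? 0 : 1; }
```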