Merged

Commits (71)
b1b4364
Rename PlainNet --> NetOp
reyoung Jul 26, 2017
ecf23ce
Update Backward
reyoung Jul 26, 2017
b1b13f8
Update Interface
reyoung Jul 26, 2017
00615eb
Refine OpRegistry::AddInput/AddOutput
reyoung Jul 26, 2017
a2dc961
Add fill_zeros_like op
JiayiFeng Jul 26, 2017
e32e306
Develop backward building precess of single op
JiayiFeng Jul 26, 2017
831d4e1
Refining Unittest
reyoung Jul 26, 2017
f77c63b
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 26, 2017
fa7cbfd
"backward is NetOp"
dzhwinter Jul 26, 2017
0ac79a3
Merge remote-tracking branch 'reyoung/feature/backward' into feature/…
dzhwinter Jul 26, 2017
292f2ab
"split to generic add PR"
dzhwinter Jul 26, 2017
05d9aff
Stash
reyoung Jul 27, 2017
fa6a46a
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 27, 2017
03f418c
Fix compile error
JiayiFeng Jul 27, 2017
5297bcb
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 27, 2017
9475972
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 27, 2017
f9fab14
Fix compile error
reyoung Jul 27, 2017
3d18737
Add unittest for part_of_output_are_not_need
reyoung Jul 27, 2017
70bd07a
Fix compile errors of FillZerosLikeOp
JiayiFeng Jul 27, 2017
63636d6
Stash for canpio
reyoung Jul 27, 2017
04db418
Add unitest of Backward.part_of_input_are_not_need
JiayiFeng Jul 27, 2017
28c0281
Stash
reyoung Jul 27, 2017
099bb53
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 27, 2017
3dd5fd0
Add unitest of Backward.intermediate_variable_not_need_in_linear_net
JiayiFeng Jul 27, 2017
84198f7
Add unittest
reyoung Jul 27, 2017
4461f3c
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 27, 2017
b1d8419
rename test
JiayiFeng Jul 27, 2017
d2583bd
InsertOp for NetOp
reyoung Jul 27, 2017
b9f2bb3
"wait add generic"
dzhwinter Jul 27, 2017
5713266
Merge remote-tracking branch 'reyoung/feature/backward' into feature/…
dzhwinter Jul 27, 2017
d4ab70a
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 27, 2017
a0669ea
Merge remote-tracking branch 'reyoung/feature/backward' into feature/…
dzhwinter Jul 27, 2017
7088654
"add duplicate"
dzhwinter Jul 27, 2017
404cc05
"reverse travesal"
dzhwinter Jul 27, 2017
65d2678
"add simple net test"
dzhwinter Jul 28, 2017
46d766e
Merge branch 'feature/unittest_for_inputs' into feature/backward
reyoung Jul 28, 2017
e1d1067
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 28, 2017
8bf0ca0
Fix unittest error
reyoung Jul 28, 2017
d0b25ac
Fix some unittest error
reyoung Jul 28, 2017
72839a7
fix conflict6
dzhwinter Jul 28, 2017
29d50ad
Refine unit-test
reyoung Jul 28, 2017
74cd9a7
"fix unittest"
dzhwinter Jul 28, 2017
7087a04
"add unittest"
dzhwinter Jul 28, 2017
b2e1c48
Merge remote-tracking branch 'reyoung/feature/backward' into feature/…
dzhwinter Jul 28, 2017
658588a
"format test case"
dzhwinter Jul 28, 2017
d6e0368
Add comment in backward.cc
reyoung Jul 28, 2017
e1cd719
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
dzhwinter Jul 28, 2017
71bd439
Addjust Backward.linear_net_intermediate_variable_has_no_grad
JiayiFeng Jul 28, 2017
0da5cce
"fix test case"
dzhwinter Jul 28, 2017
52054af
"fix typo"
dzhwinter Jul 28, 2017
0e337be
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 28, 2017
1197420
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 28, 2017
302046a
"fix return net error"
dzhwinter Jul 28, 2017
1de465b
Change some `ASSERT_EQ` to `EXPECT_EQ`
JiayiFeng Jul 28, 2017
dc06eaa
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 28, 2017
39cd39e
Update test
JiayiFeng Jul 28, 2017
be52868
Fix net_input_of_network_not_need_grad
reyoung Jul 28, 2017
a2e2cd7
Fix bug of TEST Backwar.linear_net_intermediate_variable_has_no_grad
JiayiFeng Jul 28, 2017
2198963
Merge branch 'feature/backward' of https://github.com/reyoung/Paddle …
JiayiFeng Jul 28, 2017
42e2fa5
Fix unittest
reyoung Jul 28, 2017
48812cd
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Jul 28, 2017
213fdad
adjust format
JiayiFeng Jul 28, 2017
f5636da
design doc
dzhwinter Jul 30, 2017
bd14660
"add part of design doc"
dzhwinter Jul 31, 2017
ca16c0d
Merge remote-tracking branch 'remotes/reyoung/feature/backward' into …
dzhwinter Jul 31, 2017
bc146e8
Merge branch 'develop' of github.com:baidu/Paddle into feature/backward
reyoung Aug 1, 2017
80baf86
Merge branch 'feature/backward' of github.com:reyoung/Paddle into fea…
reyoung Aug 1, 2017
e2fd2bd
Follow comments and merge develop
reyoung Aug 1, 2017
737ea05
Use static_cast, Fix unittest
reyoung Aug 1, 2017
9cc9907
Merge branch 'develop' of github.com:baidu/Paddle into feature/backward
reyoung Aug 1, 2017
051d6c8
Merge develop
reyoung Aug 1, 2017
92 changes: 40 additions & 52 deletions paddle/framework/backward.cc
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
return true;
}

static std::vector<size_t> InSetIdx(
const std::vector<std::string>& names, const std::string& suffix,
const std::unordered_set<std::string>& set) {
std::vector<size_t> ret_val;
ret_val.reserve(names.size());
for (size_t i = 0; i < names.size(); ++i) {
if (set.find(names[i] + suffix) != set.end()) {
ret_val.push_back(i);
}
}
return ret_val;
}

static std::shared_ptr<OperatorBase> EmptyOp() {
static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<NetOp>();
net_op->type_ = "@EMPTY_OP@";
net_op->type_ = "@NOP@";
net_op->CompleteAddOp();
return net_op;
}

/**
* @brief Backward an operator, implementation
* @param forwardOp the forward operator
* @param no_grad_names variable names not calculate for gradient. Like X@GRAD
* is not needed.
* @param uniq_id a unique index used inside BackwardImpl, it will be shared
* through recursive invoke.
* @return The backward operator. For simple situation, it is a simple operator.
* For complex situation, it is a NetOp.
*
* See Backward.h for details
*/
static std::shared_ptr<OperatorBase> BackwardImpl(
// Get backward operator from a forward operator, recursively implementation.
//
// no_grad_names the gradient variable names without gradient calculating.
//
// uniq_id is a unique index used inside recursively calling BackwardRecursive.
// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
// recursive calling.
//
// returns The backward operator. For simple situation, it is a simple
// operator. For complex situation, it is a NetOp.
//
// See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
/**
* If all input gradients of forwarding operator do not need to calculate,
* just return an EmptyOp. Not return null ptr because EmptyOp does not take
* too much time for calculation, but it is useful for simplifying logic.
*/
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
return EmptyOp();
return NOP();
}

/**
* All output gradients of forwarding operator do not need to calculate. Then
* all input gradients cannot be computed at all, and we put them into
* `no_grad_names` set. Return an EmptyOp.
*/
// All output gradients of forwarding operator do not need to calculate. Then
// all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
for (auto& name : forwardOp.inputs_) {
/// Mark all input is not need
// Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
}
return EmptyOp();
return NOP();
}
Contributor

I can understand that when none of the inputs need their gradients computed, all of the outputs can be marked as not needing gradients. But I haven't quite figured out in what situation we mark the inputs as not needing gradients based on the outputs not needing them.

Collaborator Author

During the reverse traversal, if none of the operators that consume an output variable compute its gradient, then this operator has no way to compute its own input gradients either.

//! Returned gradient network
// Returned gradient network
auto net = std::make_shared<NetOp>();

if (forwardOp.IsNetOp()) {
Contributor

Why can't building the backward op of a NetOp be written as a method on NetOp itself? Will we run into other complex ops (SwitchOp? perhaps not the best example) that require writing yet another branch here?

Contributor

@qingqing01 My understanding is that this is the logic for building Backward, which is not at the same level as NetOp's own backward. Gradient operators are pluggable units, while Backward is part of the system core.

Collaborator Author

Will we run into other complex ops (SwitchOp? perhaps not the best example) that require writing yet another branch here?

Perhaps the backward of complex ops could be registered somewhere else, but that registration mechanism feels very inconsistent. The simplest case is certainly written this way; until we come up with a better approach, let's do it like this for now.

/// Because forwardOp is a net op, it can static_cast.
// Because forwardOp is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp);

//! Map from output gradient variable name to operator's indices in backward
//! net. That operator generates that variable.
// Map from output gradient variable name to operator's indices in backward
// net. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
Member

What does "dup" mean?

Collaborator Author

Duplicated.

size_t local_op_id = 0;
/// reversely travel forwardNet
// reversely travel forwardNet
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto fwd = *it;
auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
for (auto& out : bwd->outputs_) {
dup_output_ops[out].emplace_back(local_op_id);
}
}
/// Get unique ID for this method.
// Get unique ID for this method.
auto uid = uniq_id++;
// TODO(dzh): more comment
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}

} else {
//! TODO(fjy)
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();

// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
{grad_input}, {}));
Contributor

Need more comments for the fill_zeros_like op.

}
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return net;
}

//! See header for comments
extern std::shared_ptr<OperatorBase> Backward(
// See header for comments
std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names;
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
}
size_t uid = 0;
return BackwardImpl(forwardOp, no_grad_names, uid);
return BackwardRecursive(forwardOp, no_grad_names, uid);
}
} // namespace framework
} // namespace paddle
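For context on the recursion above, here is a minimal worked example of what Backward is expected to build for a two-op net. Only APIs visible in this diff are used; the op names "mul" and "add", their gradient ops, the `f` namespace alias, and the include paths are assumptions for illustration, not part of this change.

#include <memory>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace f = paddle::framework;

std::shared_ptr<f::OperatorBase> BuildBackwardExample() {
  // Forward net: Out1 = mul(X, W); Out2 = add(Out1, b). "mul" and "add"
  // stand for any registered operators; they are not part of this diff.
  auto fwd = std::make_shared<f::NetOp>();
  fwd->AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out1"}, {}));
  fwd->AddOp(f::OpRegistry::CreateOp("add", {"Out1", "b"}, {"Out2"}, {}));
  fwd->CompleteAddOp();

  // Ask for gradients of everything except W.
  // Expected: a NetOp holding add's grad op followed by mul's grad op
  // (the reverse of the forward order); X@GRAD, b@GRAD and Out1@GRAD are
  // produced, while W's gradient is skipped because W is in no_grad_vars.
  return f::Backward(*fwd, {"W"});
}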
8 changes: 2 additions & 6 deletions paddle/framework/backward.h
@@ -18,12 +18,8 @@
namespace paddle {
namespace framework {

/**
* @brief
* @param forwardOp
* @param no_grad_vars ignored input name of forward
* @return
*/
// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
1 change: 0 additions & 1 deletion paddle/framework/backward_test.cc
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {

ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
// LOG(INFO) << gop->Output("X" + "@GRAD");
}

TEST(Backward, simple_op_not_need_grad) {
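A hedged sketch, in the style of the test above, of what the all-inputs-no-grad case is expected to check. The op name "mul" is an illustrative assumption, and the sketch assumes the same includes and the `f` namespace alias used by backward_test.cc; it is not taken from this diff.

TEST(Backward, sketch_all_inputs_not_need_grad) {
  auto fwd = f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {});
  // With every forward input listed in no_grad_vars, Backward returns the
  // empty "@NOP@" NetOp instead of a real gradient operator.
  auto bwd = f::Backward(*fwd, {"X", "W"});
  EXPECT_TRUE(bwd->IsNetOp());
  EXPECT_EQ("@NOP@", bwd->type_);
}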
18 changes: 10 additions & 8 deletions paddle/operators/fill_zeros_like_op.cc
@@ -21,15 +21,17 @@ namespace operators {

class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1,
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Outputs of FillZerosLikeOp must all be set.");
outputs[0]->Resize(inputs[0]->dims());
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output size of AddOp must be one.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
"Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims());
}
};

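As a brief note on the reviewer request above for more documentation of fill_zeros_like: in backward.cc this op is inserted whenever a gradient input of a grad op sits in no_grad_names, so the grad op still receives a well-shaped, all-zero tensor. A minimal sketch of creating the op directly, assuming the same namespace alias as in the earlier example; "h" and "h_zeros" are placeholder variable names, not framework conventions.

// Produce a tensor "h_zeros" with the same shape as "h", filled with zeros.
auto zeros_op =
    f::OpRegistry::CreateOp("fill_zeros_like", {"h"}, {"h_zeros"}, {});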
4 changes: 2 additions & 2 deletions paddle/operators/fill_zeros_like_op.h
@@ -23,8 +23,8 @@ namespace operators {
template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel {
public:
Member

Please add a Python unit test for this operator.

Collaborator Author

We will add the unit test in another PR. @Canpio @dzhwinter

Contributor

@reyoung Got it.

void Compute(const framework::KernelContext& context) const override {
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const framework::ExecutionContext& context) const override {
Member

framework::xxx ===> xxx

Collaborator Author

Will be changed in the next PR.

auto* output = context.Output<framework::Tensor>(0);
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).setZero();
}
11 changes: 6 additions & 5 deletions paddle/operators/recurrent_network_op.cc
@@ -312,13 +312,14 @@ class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInputs(name.inlinks,
"the input that need to be segmented for each step.");
AddInputs(name.boot_memories, "variables to initialize memories.");
AddInput(name.inlinks, "the input that need to be segmented for each step.")
.SetMultiple();
AddInput(name.boot_memories, "variables to initialize memories.")
.SetMultiple();
AddInput(name.step_net, "network shared by all steps.");

AddOutputs(name.outlinks,
"the output that need to concated for all steps.");
AddOutput(name.outlinks, "the output that need to concated for all steps.")
.SetMultiple();
AddOutput(name.step_scopes, "step scopes");

// Attributes stored in AttributeMap
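The change above replaces the plural AddInputs/AddOutputs helpers with AddInput/AddOutput plus .SetMultiple(). A minimal sketch of the same pattern in a new op's proto maker; the op itself, the argument names, and the constructor parameter types OpProto/OpAttrChecker are assumptions mirroring the maker shown above, not part of this diff.

class ConcatOpMaker : public OpProtoAndCheckerMaker {
 public:
  ConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // One declaration covers a variable-length list of inputs.
    AddInput("Xs", "the tensors to be concatenated").SetMultiple();
    AddOutput("Out", "the concatenated tensor");
  }
};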