Skip to content

Commit 3ca3a20

Browse files
Yang Yang(Tony)jacquesqiao
authored andcommitted
Prune Design Doc (#4732)
* Create prune.md * modification based on comment * remove insertion * rename id to block_id * Update prune.md * formatting
1 parent 831927d commit 3ca3a20

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

doc/design/prune.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Prune
2+
3+
## Motivation
4+
5+
We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement
6+
`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc`
7+
and generate a pruned `ProgramDesc`.
8+
9+
## Challenge
10+
11+
Pruning need to support both variables and operators being evaluation targets. Consider the following
12+
different situations.
13+
14+
```python
15+
# Case 1: run foward pass.
16+
cost_np = session.run(target=cost)
17+
# Case 2: run backward passing.
18+
opts_np, _ = session.run(target=[cost, opt])
19+
# Case 3: run checkpointing
20+
_ = session.run(target=checkpoint)
21+
```
22+
23+
## Solution
24+
25+
To support evaluation of operators, we add `is_target` field in the `OpDesc`.
26+
27+
```c++
28+
message OpDesc {
29+
required string type = 3;
30+
repeated Var inputs = 1;
31+
repeated Var outputs = 2;
32+
repeated Attr attrs = 4;
33+
optional bool is_target = 5 [ default = false ];
34+
};
35+
```
36+
37+
To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599).
38+
For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with `variable` being
39+
`fetch_op`'s input. Then we also set `fetch_op` is a target.
40+
41+
### Algorithm
42+
43+
If an operator needs to be run, it must fall into one of the following cases:
44+
45+
1. It is the target.
46+
2. It is depended by some other ops, meaning its output is some other op's input.
47+
48+
The first case can be checked by `op_desc.is_traget()` . The second case can be implement as
49+
50+
```c++
51+
bool HasDependentVar(const OpDesc& op_desc, const std::set<string>& dependent_vars) {
52+
for (auto& var : op_desc.outputs()) {
53+
for (auto& argu : var.arguments()) {
54+
if (dependent_vars.count(argu) != 0) {
55+
return true;
56+
}
57+
}
58+
}
59+
return false;
60+
}
61+
```
62+
63+
Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc).

0 commit comments

Comments
 (0)