|
| 1 | +# Prune |
| 2 | + |
| 3 | +## Motivation |
| 4 | + |
| 5 | +We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement |
| 6 | +`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc` |
| 7 | +and generate a pruned `ProgramDesc`. |
| 8 | + |
| 9 | +## Challenge |
| 10 | + |
| 11 | +Pruning need to support both variables and operators being evaluation targets. Consider the following |
| 12 | +different situations. |
| 13 | + |
| 14 | +```python |
| 15 | +# Case 1: run foward pass. |
| 16 | +cost_np = session.run(target=cost) |
| 17 | +# Case 2: run backward passing. |
| 18 | +opts_np, _ = session.run(target=[cost, opt]) |
| 19 | +# Case 3: run checkpointing |
| 20 | +_ = session.run(target=checkpoint) |
| 21 | +``` |
| 22 | + |
| 23 | +## Solution |
| 24 | + |
| 25 | +To support evaluation of operators, we add `is_target` field in the `OpDesc`. |
| 26 | + |
| 27 | +```c++ |
| 28 | +message OpDesc { |
| 29 | + required string type = 3; |
| 30 | + repeated Var inputs = 1; |
| 31 | + repeated Var outputs = 2; |
| 32 | + repeated Attr attrs = 4; |
| 33 | + optional bool is_target = 5 [ default = false ]; |
| 34 | +}; |
| 35 | +``` |
| 36 | + |
| 37 | +To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599). |
| 38 | +For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with `variable` being |
| 39 | +`fetch_op`'s input. Then we also set `fetch_op` is a target. |
| 40 | + |
| 41 | +### Algorithm |
| 42 | + |
| 43 | +If an operator needs to be run, it must fall into one of the following cases: |
| 44 | + |
| 45 | +1. It is the target. |
| 46 | +2. It is depended by some other ops, meaning its output is some other op's input. |
| 47 | + |
| 48 | +The first case can be checked by `op_desc.is_traget()` . The second case can be implement as |
| 49 | + |
| 50 | +```c++ |
| 51 | +bool HasDependentVar(const OpDesc& op_desc, const std::set<string>& dependent_vars) { |
| 52 | + for (auto& var : op_desc.outputs()) { |
| 53 | + for (auto& argu : var.arguments()) { |
| 54 | + if (dependent_vars.count(argu) != 0) { |
| 55 | + return true; |
| 56 | + } |
| 57 | + } |
| 58 | + } |
| 59 | + return false; |
| 60 | +} |
| 61 | +``` |
| 62 | +
|
| 63 | +Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc). |
0 commit comments