
Commit e57a40b

Merge pull request #7140 from JiayiFeng/dev_refine_backward_no_grad_var_handling
Dev refine backward no grad var handling
2 parents 39502e6 + 33e7520 commit e57a40b


2 files changed: +23 −6 lines


doc/design/backward.md

Lines changed: 4 additions & 2 deletions
```diff
@@ -106,9 +106,11 @@ See function `_addup_repetitive_outputs_` in `backward.py` for implementation de
 
 In our framework, variables can be marked as *no_gradient*, it means that the gradient of this variable is unnecessary and can be considered as zero in model training. Apparently, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in backward pass.
 
-But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
+Another situation is that all the gradient inputs of some `grad_op` are marked as *no_gradient*, which means all of them can be considered as zeros. Because `grad_op`s are in essence the propagation of gradients, all the outputs are definitely zeros when all gradient inputs are zeros. Therefore the `grad_op` can also be skipped.
 
-This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False).
+It should be noted that all these zero gradients still need to be created and initialized by something, otherwise the following `grad_op`s that take these gradients as inputs run the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
+
+This feature is implemented in the function `_remove_no_grad_branch_`. It checks newly created `grad_op`s one by one, removes those that can be skipped, and inserts `fill_zeros_like_op` where necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute (True or False).
 
 ### Creating Backward Variables
 
```
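The pruning rule described in the design doc boils down to a small predicate. The sketch below only illustrates the two skip conditions; it is not the Paddle implementation (which follows in `backward.py` below), and the dict-based op representation plus the `@GRAD` naming convention are assumptions made for the example.

```python
# Toy sketch of the two skip conditions; a grad op is a plain dict here, not a real OpDesc.
GRAD_SUFFIX = "@GRAD"  # assumed gradient-variable naming convention


def can_skip(grad_op, no_grad_set):
    # Condition 1: every output gradient is unnecessary.
    if all(out in no_grad_set for out in grad_op["outputs"]):
        return True
    # Condition 2: every gradient input is already known to be zero, so the op
    # would only propagate zeros and all of its outputs are zeros as well.
    grad_inputs = [n for n in grad_op["inputs"] if n.endswith(GRAD_SUFFIX)]
    return bool(grad_inputs) and all(n in no_grad_set for n in grad_inputs)


op = {"inputs": ["x", "y" + GRAD_SUFFIX], "outputs": ["x" + GRAD_SUFFIX]}
print(can_skip(op, {"y" + GRAD_SUFFIX}))  # True: its only gradient input is marked no_gradient
```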

python/paddle/v2/fluid/backward.py

Lines changed: 19 additions & 4 deletions
```diff
@@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
     """
     Test if all elements of 'cands' are in set 's'
     """
+    if len(cands) == 0:
+        return False
     for c in cands:
         if not c in s:
             return False
```
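The new guard changes what the helper reports for an empty candidate list. Below is a standalone copy for illustration; the trailing `return True` is assumed from the unchanged remainder of the function, which this hunk does not show.

```python
def _all_in_set_(cands, s):
    """Test if all elements of 'cands' are in set 's' (False for an empty 'cands')."""
    if len(cands) == 0:
        return False
    for c in cands:
        if c not in s:
            return False
    return True  # assumed: the unchanged tail of the function returns True


print(_all_in_set_(["x@GRAD", "y@GRAD"], {"x@GRAD", "y@GRAD"}))  # True
print(_all_in_set_([], {"x@GRAD"}))  # False: an empty list no longer counts as "all in set"
```

Without this guard, the new `_op_can_be_removed_` helper below would treat a grad op that has no gradient inputs at all as satisfying case 2.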
```diff
@@ -136,12 +138,23 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
         1. all outputs of the grad op are in 'no_grad_set'
-        2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
+        2. all grad inputs of the grad op are in 'no_grad_set'
     """
+
+    def _op_can_be_removed_(op_desc, no_grad_set):
+        out_arg_names = op_desc.output_arg_names()
+        if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
+            return True
+        if _all_in_set_(
+                filter(lambda name: name.find(core.grad_var_suffix()) != -1,
+                       op_desc.input_arg_names()), no_grad_set):
+            no_grad_set.union(out_arg_names)
+            return True
+        return False
+
     # Remove ops whose outputs are all in no_grad_dict
     op_descs = filter(
-        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
-        op_descs)
+        lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
     # Insert fill_zeros_like_op
     to_insert = []
     for idx, op_desc in enumerate(op_descs):
```
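A hedged usage sketch of the new removal logic on duck-typed op descriptions: `FakeOpDesc` and the `@GRAD` suffix are stand-ins invented for this example (the real code works on C++ `OpDesc` objects and calls `core.grad_var_suffix()`), and the sketch records newly zeroed outputs with `set.update` so that the chain effect on downstream grad ops is visible.

```python
GRAD_SUFFIX = "@GRAD"  # assumed value of core.grad_var_suffix()


class FakeOpDesc(object):
    """Stand-in exposing only the two methods the removal pass calls."""

    def __init__(self, inputs, outputs):
        self._inputs, self._outputs = inputs, outputs

    def input_arg_names(self):
        return self._inputs

    def output_arg_names(self):
        return self._outputs


def _all_in_set_(cands, s):
    return len(cands) != 0 and all(c in s for c in cands)


def _op_can_be_removed_(op_desc, no_grad_set):
    out_arg_names = op_desc.output_arg_names()
    if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
        return True  # case 1: all outputs are unnecessary
    grad_inputs = [n for n in op_desc.input_arg_names() if GRAD_SUFFIX in n]
    if _all_in_set_(grad_inputs, no_grad_set):
        no_grad_set.update(out_arg_names)  # the outputs are now known zeros as well
        return True  # case 2: all gradient inputs are zeros
    return False


no_grad = set(["b@GRAD"])
op_descs = [
    FakeOpDesc(["a", "b@GRAD"], ["a@GRAD"]),  # removable by case 2
    FakeOpDesc(["w", "a@GRAD"], ["w@GRAD"]),  # removable too, once a@GRAD is marked
]
op_descs = [od for od in op_descs if not _op_can_be_removed_(od, no_grad)]
print(len(op_descs), sorted(no_grad))  # 0 ['a@GRAD', 'b@GRAD', 'w@GRAD']
```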
```diff
@@ -284,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
                 block_no_grad_set.add(_append_grad_suffix_(var.name))
             no_grad_dict[block.idx] = block_no_grad_set
     elif isinstance(no_grad_set, set):
-        no_grad_dict = {0: no_grad_set}
+        no_grad_dict = {
+            0: set([_append_grad_suffix_(name) for name in no_grad_set])
+        }
     else:
         raise ValueError("'no_grad_set' should be a set or None.")
 
```
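For illustration, a hedged sketch of what the new branch stores, assuming `_append_grad_suffix_` simply appends the framework's gradient suffix and that the suffix is `@GRAD` (both are assumptions, not verified against this revision): user-supplied forward-variable names are converted into gradient-variable names, which is what `_remove_no_grad_branch_` compares against.

```python
GRAD_SUFFIX = "@GRAD"  # assumed value of core.grad_var_suffix()


def _append_grad_suffix_(name):
    # assumed behaviour: map a forward variable name to its gradient variable's name
    return name + GRAD_SUFFIX


user_no_grad_set = set(["embedding", "bias"])  # names a caller might pass to append_backward

# New behaviour: store the corresponding gradient names for block 0.
no_grad_dict = {0: set([_append_grad_suffix_(n) for n in user_no_grad_set])}
print(sorted(no_grad_dict[0]))  # ['bias@GRAD', 'embedding@GRAD']
```

With the old `{0: no_grad_set}` form, the raw forward names would never match the `*@GRAD` names examined during branch removal.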
