
Commit 69c0518

Change the name of grad_recv_indices, fix the VPP hang
1 parent f1c40e1 commit 69c0518

2 files changed: +13 -11 lines


python/paddle/distributed/auto_parallel/pipelining/schedules.py

Lines changed: 7 additions & 7 deletions
@@ -1008,10 +1008,8 @@ def _get_1f1b_rank_ops(
     # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
     # warmup_ops = calculated above
     post_warmup_ops = (
-        (n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank))
-        - (warmup_ops + rank)
-        - 1
-    )
+        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
+    ) - (warmup_ops + rank)
 
     if enable_zero_bubble:
         post_warmup_ops = pp_group_size - rank - 1
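
Note: the reworked expression drops the trailing `- 1`, so each rank idles one step less before its first backward. A minimal sketch of the arithmetic, with illustrative values (the real `n_local_stages`, `pp_group_size`, `rank`, and `warmup_ops` come from the schedule setup):

    # Illustrative values only; the real ones come from the schedule setup.
    n_local_stages, pp_group_size, rank, warmup_ops = 2, 4, 1, 6

    # Before this commit: one extra idle step.
    old_post_warmup_ops = (
        (n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank))
        - (warmup_ops + rank)
        - 1
    )

    # After this commit: the first backward is scheduled one step earlier.
    new_post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

    assert new_post_warmup_ops == old_post_warmup_ops + 1  # 5 vs. 4 here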
@@ -1076,11 +1074,13 @@
                 )
             )
             weight_op_count += 1
-            if op == warmup_ops + fwd_bwd_ops - 1:
-                # This is the last step in the 1F1B Phase, the bubbles are symmetrical with respect to the ending phase of the warm_up
-                rank_ops.extend([None] * post_warmup_ops)
         # Cooldown phase
         else:
+            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
+            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
+            if not enable_zero_bubble:
+                rank_ops.append(None)
+
             bwd_stage_index = backward_stage_index(op)
             bwd_stage_mb_index[bwd_stage_index] = (
                 bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]

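Note: instead of padding a single block of `None` bubbles at the end of 1F1B, the cooldown loop now interleaves one idle step per backward. A toy sketch of the resulting pattern (not Paddle's actual scheduler; `rank_ops` and `enable_zero_bubble` mirror the diff, the op labels are invented):

    # Toy illustration of the cooldown padding added above; values invented.
    enable_zero_bubble = False
    rank_ops = []

    for bwd_op in ["B5", "B6", "B7"]:  # hypothetical cooldown backward ops
        if not enable_zero_bubble:
            # Idle one step so this rank's timeline lines up with ranks
            # that are still executing 1F1B forward/backward pairs.
            rank_ops.append(None)
        rank_ops.append(bwd_op)

    print(rank_ops)  # [None, 'B5', None, 'B6', None, 'B7']
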
python/paddle/distributed/auto_parallel/pipelining/stage.py

Lines changed: 6 additions & 4 deletions
@@ -205,7 +205,9 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}
-        self.grad_recv_indices: dict[int, list] = {}
+        self._need_grad_indices: dict[int, list] = (
+            {}
+        )  # record the index of output that needs to receive grad from the next stage.
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
@@ -1121,10 +1123,10 @@ def _prepare_forward_infra(
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
                 if not outputs_meta[idx].stop_gradient:
-                    self.grad_recv_indices[idx] = [self.stage_index + 1]
+                    self._need_grad_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
-                self.grad_recv_indices[idx] = []
+                self._need_grad_indices[idx] = []
 
         return outputs
 

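Note: to make the renamed mapping concrete, here is a hypothetical result of the loop above for a stage with `stage_index = 2` and three outputs, where output 1 has `stop_gradient=True` (all indices invented for illustration):

    # Hypothetical contents of self._need_grad_indices after the loop above.
    _need_grad_indices = {
        0: [3],  # output 0 receives its grad from stage 3 (stage_index + 1)
        2: [3],  # output 2 likewise
        # output 1 gets no entry: stop_gradient=True, so no grad is received
    }
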
@@ -1239,7 +1241,7 @@ def _create_grad_recv_info(
                     dst_list[0],
                     _make_tensor_from_meta(self.grads_meta[idx]),
                 )
-                for idx, dst_list in self.grad_recv_indices.items()
+                for idx, dst_list in self._need_grad_indices.items()
             ]
         )
         return grad_recv_info

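Note: for context, a simplified standalone sketch of how the comprehension in `_create_grad_recv_info` consumes `_need_grad_indices`; the `_make_tensor_from_meta` stub and the meta dicts below stand in for Paddle's real helpers and are illustration only:

    # Simplified stand-ins for Paddle's helpers, for illustration only.
    def _make_tensor_from_meta(meta):
        return {"shape": meta["shape"], "dtype": meta["dtype"]}

    grads_meta = {0: {"shape": [8, 16], "dtype": "float32"}}
    _need_grad_indices = {0: [3]}  # output 0 expects its grad from stage 3

    grad_recv_info = tuple(
        (dst_list[0], _make_tensor_from_meta(grads_meta[idx]))
        for idx, dst_list in _need_grad_indices.items()
    )
    print(grad_recv_info)  # ((3, {'shape': [8, 16], 'dtype': 'float32'}),)
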