
Commit f1c40e1

Enhance the stage's support for arguments with stop_grad=True passed between layers

1 parent 497bc37 commit f1c40e1
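The change concerns tensors that cross a stage boundary with stop_grad=True (masks, position ids, detached features, and so on): they are still sent forward as activations, but no gradient will ever flow back for them, so the stage should neither cache them for backward nor expect a gradient recv for them. A minimal standalone Paddle snippet (not part of this commit) illustrating the property the diff relies on:

import paddle

x = paddle.randn([2, 3])
x.stop_gradient = False             # x participates in autograd
mask = (x > 0).astype("float32")    # derived tensor: stop_gradient stays True
loss = (x * mask).sum()

loss.backward()
print(x.grad is not None)    # True: a gradient is produced for x
print(mask.stop_gradient)    # True: no gradient will ever be produced for mask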

File tree

1 file changed: +30 -15 lines changed
  • python/paddle/distributed/auto_parallel/pipelining


python/paddle/distributed/auto_parallel/pipelining/stage.py

Lines changed: 30 additions & 15 deletions
@@ -205,7 +205,7 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}
-
+        self.grad_recv_indices: dict[int, list] = {}
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
@@ -481,7 +481,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
                         else dtensor_to_local(
                             out,
                             out.process_mesh,
-                            self.grads_meta[idx].placements,
+                            out.placements,
                         )
                     ),
                     peer_global_rank,
@@ -620,9 +620,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
     def backward_maybe_with_nosync(
         self, backward_type, bwd_kwargs: dict, last_backward=False
     ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
-        """
-        When PP and DP are used together, some behaviors may differ at the start of the backward pass of the last microbatch of each batch; this may need attention.
-        """

         def perform_backward(
             backward_type,
@@ -711,9 +708,19 @@ def forward_one_chunk(
         flat_args = _flatten_args(input_args)
         flat_kwargs = _flatten_args(composite_kwargs)
         flatten_input_tensors = flat_args + flat_kwargs
+        grad_required_output_tuple = tuple(
+            out
+            for out in output_tuple
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
+        grad_required_flatten_input_tensors = [
+            inp
+            for inp in flatten_input_tensors
+            if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
+        ]
         self.fwd_cache[fwd_chunk_id] = (
-            output_tuple,  # stage_output
-            flatten_input_tensors,  # input_values
+            grad_required_output_tuple,  # stage_output
+            grad_required_flatten_input_tensors,  # input_values
         )

         logger.debug(
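The hunk above narrows what forward_one_chunk caches for backward: only tensors whose stop_gradient is False are kept as stage_output / input_values, so stop_grad=True arguments passed between layers no longer enter the backward bookkeeping. A small hedged sketch of the same filtering in isolation (the helper name is illustrative, not part of stage.py):

import paddle

def filter_grad_required(outputs, inputs):
    # Mirror of the grad_required_* filtering: drop anything that cannot
    # receive a gradient (non-tensors and tensors with stop_gradient=True).
    grad_outputs = tuple(
        t for t in outputs
        if isinstance(t, paddle.Tensor) and not t.stop_gradient
    )
    grad_inputs = [
        t for t in inputs
        if isinstance(t, paddle.Tensor) and not t.stop_gradient
    ]
    return grad_outputs, grad_inputs

a = paddle.randn([2])
a.stop_gradient = False
m = paddle.ones([2])                    # stop_gradient defaults to True
outs, ins = filter_grad_required((a * m, m), [a, m])
print(len(outs), len(ins))              # 1 1: the mask m is dropped from both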
@@ -1005,8 +1012,16 @@ def requires_grad(x):
             flatten_input_tensors = [
                 x for x in flatten_input_tensors if not x.stop_gradient
             ]
+            grad_required_outputs = _normalize_model_output_as_tuple(
+                outputs
+            )
+            grad_required_outputs = tuple(
+                out
+                for out in grad_required_outputs
+                if isinstance(out, paddle.Tensor) and not out.stop_gradient
+            )
             self.fwd_cache[0] = (
-                _normalize_model_output_as_tuple(outputs),  # stage_output
+                grad_required_outputs,  # stage_output
                 flatten_input_tensors,  # input_values
             )

@@ -1100,13 +1115,16 @@ def _prepare_forward_infra(
         # Send info during forward for each activation
         # only need the rank that is being sent to
         self.act_send_info: dict[int, list] = {}
-
-        for idx in range(len(self.get_outputs_meta())):
+        outputs_meta = self.get_outputs_meta()
+        for idx in range(len(outputs_meta)):
             # We assume we always send to stage + 1
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
+                if not outputs_meta[idx].stop_gradient:
+                    self.grad_recv_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
+                self.grad_recv_indices[idx] = []

         return outputs

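With grad_recv_indices populated here, the stage keeps two views of its outputs: act_send_info says which outputs are sent forward (all of them), while grad_recv_indices records only the indices that should get a gradient back. A hedged standalone sketch of that bookkeeping, using plain tensors as a stand-in for the stage's outputs_meta:

import paddle

# Stand-in for self.get_outputs_meta(): one grad-requiring activation, one mask.
act = paddle.randn([4])
act.stop_gradient = False
mask = paddle.ones([4])                     # stop_gradient is True

act_send_info: dict[int, list] = {}
grad_recv_indices: dict[int, list] = {}
next_stage, is_last = 1, False

for idx, meta in enumerate([act, mask]):
    if not is_last:
        act_send_info[idx] = [next_stage]   # every output is still sent forward
        if not meta.stop_gradient:
            grad_recv_indices[idx] = [next_stage]  # only this one expects a grad back
    else:
        act_send_info[idx] = []
        grad_recv_indices[idx] = []

print(act_send_info)        # {0: [1], 1: [1]}
print(grad_recv_indices)    # {0: [1]}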
@@ -1195,9 +1213,7 @@ def _prepare_backward_infra(
         self._shape_inference_bwd()
         for mb_index in range(num_microbatches):
             # `grad_recv_info` is a mirror of `act_send_info`
-            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
-                self.act_send_info
-            )
+            self.grad_recv_info[mb_index] = self._create_grad_recv_info()

         # the last stage does not need recv grads from other rank
         if not self.is_last:
@@ -1211,7 +1227,6 @@ def _prepare_backward_infra(

     def _create_grad_recv_info(
         self,
-        act_send_info: dict,
     ) -> tuple[_RecvInfo, ...]:
         grad_recv_info: tuple[_RecvInfo, ...] = ()
         if not self.is_last:
12241239
dst_list[0],
12251240
_make_tensor_from_meta(self.grads_meta[idx]),
12261241
)
1227-
for idx, dst_list in act_send_info.items()
1242+
for idx, dst_list in self.grad_recv_indices.items()
12281243
]
12291244
)
12301245
return grad_recv_info
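Downstream of these changes, gradient recv buffers are only built for indices present in grad_recv_indices, and the backward pass only has to produce gradients for the filtered cached tensors. A rough standalone sketch (the helper name and signature are hypothetical, not the stage's actual API) of running backward over the cached, grad-required tensors:

import paddle

def backward_for_cached(stage_output, output_grads, input_values):
    # stage_output / input_values are assumed to be the already-filtered
    # tuples cached in fwd_cache, i.e. every tensor has stop_gradient == False.
    paddle.autograd.backward(list(stage_output), list(output_grads))
    return tuple(inp.grad for inp in input_values)

x = paddle.randn([3])
x.stop_gradient = False
out = (x * 2.0,)                                 # pretend stage output
grads = backward_for_cached(out, (paddle.ones([3]),), [x])
print(grads[0])                                  # d(out)/dx = 2.0 everywhere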
