
Commit 388eaa7

[Auto-parallel] Enhance stage support for parameters passed between layers with stop_grad=True (#73459)

* Enhance stage support for parameters passed between layers with stop_grad=True
* Rename grad_recv_indices and fix the VPP hang
* Add a stop_gradient attribute to TensorMeta
* Modify variable naming
1 parent 8b88b53 commit 388eaa7
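
In short: when a stage forwards an output whose `stop_gradient` is True (position ids, attention masks, and the like), the downstream stage never sends a gradient back for it, so the upstream stage must not schedule a gradient receive for that slot, otherwise schedules such as VPP can hang waiting on it. The sketch below is a minimal, hypothetical illustration of that bookkeeping (the helper `need_grad_indices` and the example tensors are invented for this note, not Paddle's code); the actual change records the same information in `self._need_grad_indices` and in `TensorMeta.stop_gradient`.

```python
# Minimal sketch, not Paddle's implementation: decide, per forwarded output,
# whether a gradient will ever flow back from the next stage.
import paddle


def need_grad_indices(outputs: tuple) -> list[int]:
    """Indices of outputs for which a grad-recv should be scheduled."""
    return [
        idx
        for idx, out in enumerate(outputs)
        if isinstance(out, paddle.Tensor) and not out.stop_gradient
    ]


hidden = paddle.ones([4, 8])
hidden.stop_gradient = False     # normal activation: a gradient will arrive
position_ids = paddle.arange(8)  # stop_gradient=True: no gradient will ever arrive

print(need_grad_indices((hidden, position_ids)))  # -> [0]
```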

2 files changed: 33 additions, 15 deletions
python/paddle/distributed/auto_parallel/pipelining/stage.py

Lines changed: 32 additions & 15 deletions
@@ -204,7 +204,9 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}
-
+        self._need_grad_indices: dict[int, list] = (
+            {}
+        )  # record the index of output that needs to receive grad from the next stage.
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
@@ -480,7 +482,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
                         else dtensor_to_local(
                             out,
                             out.process_mesh,
-                            self.grads_meta[idx].placements,
+                            out.placements,
                         )
                     ),
                     peer_global_rank,
@@ -619,9 +621,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
     def backward_maybe_with_nosync(
         self, backward_type, bwd_kwargs: dict, last_backward=False
     ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
-        """
-        When PP is mixed with DP, some behavior at the start of the backward pass of the last microbatch of each batch may differ and may need attention.
-        """
 
         def perform_backward(
             backward_type,
@@ -710,9 +709,19 @@ def forward_one_chunk(
         flat_args = _flatten_args(input_args)
         flat_kwargs = _flatten_args(composite_kwargs)
         flatten_input_tensors = flat_args + flat_kwargs
+        requires_grad_output_tuple = tuple(
+            out
+            for out in output_tuple
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
+        flatten_requires_grad_input_tensors = [
+            inp
+            for inp in flatten_input_tensors
+            if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
+        ]
         self.fwd_cache[fwd_chunk_id] = (
-            output_tuple,  # stage_output
-            flatten_input_tensors,  # input_values
+            requires_grad_output_tuple,  # stage_output
+            flatten_requires_grad_input_tensors,  # input_values
         )
 
         logger.debug(
@@ -1169,8 +1178,16 @@ def requires_grad(x):
         flatten_input_tensors = [
             x for x in flatten_input_tensors if not x.stop_gradient
         ]
+        grad_required_outputs = _normalize_model_output_as_tuple(
+            outputs
+        )
+        grad_required_outputs = tuple(
+            out
+            for out in grad_required_outputs
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
         self.fwd_cache[0] = (
-            _normalize_model_output_as_tuple(outputs),  # stage_output
+            grad_required_outputs,  # stage_output
             flatten_input_tensors,  # input_values
         )
 
@@ -1264,13 +1281,16 @@ def _prepare_forward_infra(
         # Send info during forward for each activation
         # only need the rank that is being sent to
         self.act_send_info: dict[int, list] = {}
-
-        for idx in range(len(self.get_outputs_meta())):
+        outputs_meta = self.get_outputs_meta()
+        for idx in range(len(outputs_meta)):
             # We assume we always send to stage + 1
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
+                if not outputs_meta[idx].stop_gradient:
+                    self._need_grad_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
+                self._need_grad_indices[idx] = []
 
         return outputs
 
@@ -1359,9 +1379,7 @@ def _prepare_backward_infra(
         self._shape_inference_bwd()
         for mb_index in range(num_microbatches):
             # `grad_recv_info` is a mirror of `act_send_info`
-            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
-                self.act_send_info
-            )
+            self.grad_recv_info[mb_index] = self._create_grad_recv_info()
 
         # the last stage does not need recv grads from other rank
         if not self.is_last:
@@ -1375,7 +1393,6 @@ def _prepare_backward_infra(
 
     def _create_grad_recv_info(
         self,
-        act_send_info: dict,
     ) -> tuple[_RecvInfo, ...]:
         grad_recv_info: tuple[_RecvInfo, ...] = ()
         if not self.is_last:
@@ -1388,7 +1405,7 @@ def _create_grad_recv_info(
                         dst_list[0],
                         _make_tensor_from_meta(self.grads_meta[idx]),
                     )
-                    for idx, dst_list in act_send_info.items()
+                    for idx, dst_list in self._need_grad_indices.items()
                 ]
             )
         return grad_recv_info
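
A side effect of caching only grad-requiring outputs (the `requires_grad_output_tuple` / `grad_required_outputs` changes above) is that the cached `stage_output` and the gradients received for it stay the same length, so they can be fed to autograd pairwise. Below is a toy, non-distributed sketch of that pairing; the tensors and the call to `paddle.autograd.backward` are illustrative assumptions, not the stage's actual backward path.

```python
import paddle

x = paddle.ones([2, 2])
x.stop_gradient = False
const = paddle.full([2, 2], 2.0)  # stop_gradient=True by default

y = x * 3.0                       # needs a gradient from downstream
aux = const + 1.0                 # stop_gradient=True, filtered out of the cache

stage_output = [t for t in (y, aux) if not t.stop_gradient]    # -> [y]
received_grads = [paddle.ones_like(t) for t in stage_output]   # one grad per cached output

paddle.autograd.backward(stage_output, received_grads)
print(x.grad)                     # 3.0 everywhere: d(3x)/dx * 1
```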

python/paddle/distributed/auto_parallel/pipelining/utils.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def __init__(self, tensor: paddle.Tensor):
         self._local_shape = None
         self.dtype = tensor.dtype
         self.placements = None if not tensor.is_dist() else tensor.placements
+        self.stop_gradient = tensor.stop_gradient
 
     def __repr__(self):
         return f"TensorMeta(global_shape={self.shape},local_shape={self._local_shape}, dtype={self.dtype}, placements={self.placements})"
