47 changes: 32 additions & 15 deletions python/paddle/distributed/auto_parallel/pipelining/stage.py
@@ -204,7 +204,9 @@ def __init__(
# Forward infra
self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
self.act_send_info: dict[int, list] = {}

self._need_grad_indices: dict[int, list] = (
{}
)  # Record the indices of outputs that need to receive grads from the next stage.
# Backward infra will be created lazily
self.grad_recv_info: dict = {}
self.grad_send_info: list | None = None
@@ -480,7 +482,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
else dtensor_to_local(
out,
out.process_mesh,
self.grads_meta[idx].placements,
out.placements,
)
),
peer_global_rank,
@@ -619,9 +621,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
def backward_maybe_with_nosync(
self, backward_type, bwd_kwargs: dict, last_backward=False
) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
"""
When PP and DP are used together, some behaviors may differ at the start of the backward pass of the last microbatch of each batch, so extra care may be needed at that point.
"""

def perform_backward(
backward_type,
@@ -710,9 +709,19 @@ def forward_one_chunk(
flat_args = _flatten_args(input_args)
flat_kwargs = _flatten_args(composite_kwargs)
flatten_input_tensors = flat_args + flat_kwargs
grad_required_output_tuple = tuple(
Contributor: The naming seems problematic: a name starting with grad reads as if it held gradient data. Naming it requires_grad_output_tuple would be better.

Contributor Author: Done

out
for out in output_tuple
if isinstance(out, paddle.Tensor) and not out.stop_gradient
)
grad_required_flatten_input_tensors = [
Contributor: Same as above.

Contributor Author: Done

inp
for inp in flatten_input_tensors
if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
]
self.fwd_cache[fwd_chunk_id] = (
output_tuple, # stage_output
flatten_input_tensors, # input_values
grad_required_output_tuple, # stage_output
grad_required_flatten_input_tensors, # input_values
)

logger.debug(
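
As a rough sketch of the caching rule in this hunk (the helper name filter_grad_required is illustrative, not part of the stage API): only paddle.Tensor objects whose stop_gradient flag is False are kept for the backward pass, so tensors that can never receive a gradient are not carried into fwd_cache.

import paddle

def filter_grad_required(tensors):
    # Keep only paddle.Tensor objects that can still receive a gradient
    # (stop_gradient is False); anything else is dropped from the cache.
    return tuple(
        t
        for t in tensors
        if isinstance(t, paddle.Tensor) and not t.stop_gradient
    )

x = paddle.rand([2, 2])
x.stop_gradient = False  # x participates in autograd
y = paddle.rand([2, 2])  # stop_gradient defaults to True
cached = filter_grad_required((x, y, "not a tensor"))
assert len(cached) == 1 and cached[0] is x
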
@@ -1169,8 +1178,16 @@ def requires_grad(x):
flatten_input_tensors = [
x for x in flatten_input_tensors if not x.stop_gradient
]
grad_required_outputs = _normalize_model_output_as_tuple(
outputs
)
grad_required_outputs = tuple(
out
for out in grad_required_outputs
if isinstance(out, paddle.Tensor) and not out.stop_gradient
)
self.fwd_cache[0] = (
_normalize_model_output_as_tuple(outputs), # stage_output
grad_required_outputs, # stage_output
flatten_input_tensors, # input_values
)

@@ -1264,13 +1281,16 @@ def _prepare_forward_infra(
# Send info during forward for each activation
# only need the rank that is being sent to
self.act_send_info: dict[int, list] = {}

for idx in range(len(self.get_outputs_meta())):
outputs_meta = self.get_outputs_meta()
for idx in range(len(outputs_meta)):
# We assume we always send to stage + 1
if not self.is_last:
self.act_send_info[idx] = [self.stage_index + 1]
if not outputs_meta[idx].stop_gradient:
Contributor: TensorMeta does not have this attribute.

Contributor Author: Done

self._need_grad_indices[idx] = [self.stage_index + 1]
else:
self.act_send_info[idx] = []
self._need_grad_indices[idx] = []

return outputs
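
A simplified sketch of the bookkeeping above, with a stand-in metadata type instead of the real stage/TensorMeta objects (FakeMeta and build_send_and_grad_info are made-up names): every output index is scheduled to be sent to stage_index + 1, but only outputs whose metadata has stop_gradient=False are recorded as needing a gradient back.

from dataclasses import dataclass

@dataclass
class FakeMeta:
    # Stand-in for the per-output metadata; only the flag used here is modeled.
    stop_gradient: bool

def build_send_and_grad_info(outputs_meta, stage_index, is_last):
    act_send_info: dict[int, list] = {}
    need_grad_indices: dict[int, list] = {}
    for idx, meta in enumerate(outputs_meta):
        if not is_last:
            # Activations always go to the next stage.
            act_send_info[idx] = [stage_index + 1]
            # Only outputs that require grad expect a gradient back.
            need_grad_indices[idx] = (
                [stage_index + 1] if not meta.stop_gradient else []
            )
        else:
            act_send_info[idx] = []
            need_grad_indices[idx] = []
    return act_send_info, need_grad_indices

# Output 0 requires grad, output 1 does not.
send, need = build_send_and_grad_info(
    [FakeMeta(False), FakeMeta(True)], stage_index=0, is_last=False
)
assert send == {0: [1], 1: [1]}
assert need == {0: [1], 1: []}
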

@@ -1359,9 +1379,7 @@ def _prepare_backward_infra(
self._shape_inference_bwd()
for mb_index in range(num_microbatches):
# `grad_recv_info` is a mirror of `act_send_info`
self.grad_recv_info[mb_index] = self._create_grad_recv_info(
self.act_send_info
)
self.grad_recv_info[mb_index] = self._create_grad_recv_info()

# the last stage does not need recv grads from other rank
if not self.is_last:
@@ -1375,7 +1393,6 @@

def _create_grad_recv_info(
self,
act_send_info: dict,
) -> tuple[_RecvInfo, ...]:
grad_recv_info: tuple[_RecvInfo, ...] = ()
if not self.is_last:
@@ -1388,7 +1405,7 @@ def _create_grad_recv_info(
dst_list[0],
_make_tensor_from_meta(self.grads_meta[idx]),
)
for idx, dst_list in act_send_info.items()
for idx, dst_list in self._need_grad_indices.items()
]
)
return grad_recv_info
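
Conceptually, _create_grad_recv_info now mirrors _need_grad_indices rather than act_send_info. A rough sketch under that assumption, with made-up helpers standing in for _RecvInfo, _make_tensor_from_meta, and grads_meta: a receive buffer is allocated only for output indices that actually expect a gradient back.

import paddle

def create_grad_recv_buffers(need_grad_indices, grads_meta):
    # need_grad_indices: {output_idx: [peer_stage]} for outputs expecting a grad.
    # grads_meta: {output_idx: (shape, dtype)} describing each incoming grad.
    buffers = {}
    for idx, peers in need_grad_indices.items():
        if not peers:
            continue  # this output never receives a gradient
        shape, dtype = grads_meta[idx]
        buffers[idx] = (peers[0], paddle.empty(shape, dtype=dtype))
    return buffers

buffers = create_grad_recv_buffers(
    {0: [2], 1: []},
    {0: ([4, 8], "float32"), 1: ([4, 8], "float32")},
)
assert list(buffers.keys()) == [0]
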
@@ -123,6 +123,7 @@ def __init__(self, tensor: paddle.Tensor):
self._local_shape = None
self.dtype = tensor.dtype
self.placements = None if not tensor.is_dist() else tensor.placements
self.stop_gradient = tensor.stop_gradient

def __repr__(self):
return f"TensorMeta(global_shape={self.shape},local_shape={self._local_shape}, dtype={self.dtype}, placements={self.placements})"