[Auto-parallel] Enhance stage support for passing stop_grad=True parameters between layers #73459
Changes from 4 commits: f1c40e1, 69c0518, ea80cdb, 7511958, 10f31c2
```diff
@@ -204,7 +204,9 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}

+        self._need_grad_indices: dict[int, list] = (
+            {}
+        )  # record the index of output that needs to receive grad from the next stage.
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
```
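For context on the flag this new bookkeeping keys on: Paddle tensors carry a `stop_gradient` attribute (True by default for freshly created tensors), and an output with `stop_gradient=True` never receives a gradient from the next stage. A minimal standalone sketch of the idea behind `_need_grad_indices`; the variable names here are illustrative, not the PR's code:

```python
import paddle

# Two stage outputs: a trainable activation and a frozen buffer.
act = paddle.randn([2, 4])
act.stop_gradient = False      # will receive a grad from stage + 1
frozen = paddle.randn([2, 4])  # stop_gradient is True by default

outputs = (act, frozen)
next_stage = 1

# Mirrors what `_need_grad_indices` records: only outputs that require
# grad are mapped to the peer stage their gradient will come from.
need_grad_indices = {
    idx: [next_stage]
    for idx, out in enumerate(outputs)
    if isinstance(out, paddle.Tensor) and not out.stop_gradient
}
print(need_grad_indices)  # {0: [1]}
```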
```diff
@@ -480,7 +482,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
                 else dtensor_to_local(
                     out,
                     out.process_mesh,
-                    self.grads_meta[idx].placements,
+                    out.placements,
                 )
             ),
             peer_global_rank,
```
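The one-line fix above describes the outgoing activation by its own layout rather than by the cached gradient meta, whose placements may differ or not exist at all for a `stop_gradient=True` output. A rough sketch of the attributes involved, assuming Paddle's auto-parallel API (`ProcessMesh`, `shard_tensor`, `Shard`) and a multi-card launch:

```python
import paddle
import paddle.distributed as dist

# Needs a distributed launch, e.g.:
#   python -m paddle.distributed.launch --devices 0,1 demo.py
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
out = dist.shard_tensor(paddle.randn([4, 8]), mesh, [dist.Shard(0)])

# The send path can describe `out` by its own layout; a grads-meta
# entry may be missing for outputs that never receive a gradient.
print(out.process_mesh)
print(out.placements)  # e.g. [Shard(dim=0)]
```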
```diff
@@ -619,9 +621,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
     def backward_maybe_with_nosync(
         self, backward_type, bwd_kwargs: dict, last_backward=False
     ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
-        """
-        When PP and DP are mixed, behavior may differ at the start of the backward pass of the last microbatch in each batch; take care here.
-        """

         def perform_backward(
             backward_type,
```
```diff
@@ -710,9 +709,19 @@ def forward_one_chunk(
         flat_args = _flatten_args(input_args)
         flat_kwargs = _flatten_args(composite_kwargs)
         flatten_input_tensors = flat_args + flat_kwargs
+        grad_required_output_tuple = tuple(
+            out
+            for out in output_tuple
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
+        grad_required_flatten_input_tensors = [
+            inp
+            for inp in flatten_input_tensors
+            if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
+        ]
         self.fwd_cache[fwd_chunk_id] = (
-            output_tuple,  # stage_output
-            flatten_input_tensors,  # input_values
+            grad_required_output_tuple,  # stage_output
+            grad_required_flatten_input_tensors,  # input_values
         )

         logger.debug(
```

Review comment (on `grad_required_flatten_input_tensors`): Same as above. Reply: Done
```diff
@@ -1169,8 +1178,16 @@ def requires_grad(x):
             flatten_input_tensors = [
                 x for x in flatten_input_tensors if not x.stop_gradient
             ]
+            grad_required_outputs = _normalize_model_output_as_tuple(
+                outputs
+            )
+            grad_required_outputs = tuple(
+                out
+                for out in grad_required_outputs
+                if isinstance(out, paddle.Tensor) and not out.stop_gradient
+            )
             self.fwd_cache[0] = (
-                _normalize_model_output_as_tuple(outputs),  # stage_output
+                grad_required_outputs,  # stage_output
                 flatten_input_tensors,  # input_values
             )
```
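Both hunks above apply the same filter before populating `fwd_cache`: anything that is not a Paddle tensor, or that has `stop_gradient=True`, cannot take part in the backward pass and is dropped from the cache. A self-contained sketch of that filter; the helper name is hypothetical:

```python
import paddle

def filter_grad_required(tensors):
    # Keep only entries that can take part in the backward pass.
    return [
        t for t in tensors
        if isinstance(t, paddle.Tensor) and not t.stop_gradient
    ]

w = paddle.randn([3, 3])
w.stop_gradient = False
mask = paddle.ones([3, 3])  # stop_gradient=True by default
out = paddle.matmul(w, mask)

# `mask` is dropped; `w` and `out` survive.
print(len(filter_grad_required([w, mask, out])))  # 2
```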
```diff
@@ -1264,13 +1281,16 @@ def _prepare_forward_infra(
         # Send info during forward for each activation
         # only need the rank that is being sent to
         self.act_send_info: dict[int, list] = {}

-        for idx in range(len(self.get_outputs_meta())):
+        outputs_meta = self.get_outputs_meta()
+        for idx in range(len(outputs_meta)):
             # We assume we always send to stage + 1
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
+                if not outputs_meta[idx].stop_gradient:
+                    self._need_grad_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
+                self._need_grad_indices[idx] = []

         return outputs
```

Review comment (on `outputs_meta[idx].stop_gradient`): TensorMeta does not have this attribute. Reply: Done
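The divergence between the two maps is the heart of the change: every activation is still sent forward, but only differentiable outputs get an entry in `_need_grad_indices`. A standalone sketch of the same control flow, using a hypothetical `StubMeta` in place of the stage's real output meta (which, per the review thread above, initially lacked a `stop_gradient` attribute):

```python
from dataclasses import dataclass

@dataclass
class StubMeta:  # hypothetical stand-in for the stage's output meta
    stop_gradient: bool

def build_send_maps(outputs_meta, stage_index, is_last):
    act_send_info, need_grad_indices = {}, {}
    for idx, meta in enumerate(outputs_meta):
        if not is_last:
            # Every activation is forwarded to the next stage...
            act_send_info[idx] = [stage_index + 1]
            # ...but only differentiable outputs expect a grad back.
            if not meta.stop_gradient:
                need_grad_indices[idx] = [stage_index + 1]
        else:
            act_send_info[idx] = []
            need_grad_indices[idx] = []
    return act_send_info, need_grad_indices

print(build_send_maps([StubMeta(False), StubMeta(True)], 0, False))
# ({0: [1], 1: [1]}, {0: [1]})
```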
```diff
@@ -1359,9 +1379,7 @@ def _prepare_backward_infra(
         self._shape_inference_bwd()
         for mb_index in range(num_microbatches):
             # `grad_recv_info` is a mirror of `act_send_info`
-            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
-                self.act_send_info
-            )
+            self.grad_recv_info[mb_index] = self._create_grad_recv_info()

         # the last stage does not need recv grads from other rank
         if not self.is_last:
```
```diff
@@ -1375,7 +1393,6 @@ def _prepare_backward_infra(

     def _create_grad_recv_info(
         self,
-        act_send_info: dict,
     ) -> tuple[_RecvInfo, ...]:
         grad_recv_info: tuple[_RecvInfo, ...] = ()
         if not self.is_last:
```
```diff
@@ -1388,7 +1405,7 @@ def _create_grad_recv_info(
                     dst_list[0],
                     _make_tensor_from_meta(self.grads_meta[idx]),
                 )
-                for idx, dst_list in act_send_info.items()
+                for idx, dst_list in self._need_grad_indices.items()
             ]
         )
         return grad_recv_info
```
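Switching the comprehension's source from `act_send_info` to `_need_grad_indices` shrinks the set of gradient receive buffers to match what will actually arrive. A small illustration with hypothetical map contents for a non-last stage:

```python
# Hypothetical contents on a non-last stage with two outputs, where
# output 1 has stop_gradient=True:
act_send_info = {0: [1], 1: [1]}  # every activation is sent forward
need_grad_indices = {0: [1]}      # only output 0 gets a grad back

# Before this change, receive buffers mirrored act_send_info, so a
# buffer was allocated for output 1 even though no grad ever arrives.
grad_recv_peers = {idx: dst[0] for idx, dst in need_grad_indices.items()}
print(grad_recv_peers)  # {0: 1}
```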
Review comment: The naming seems problematic. A name that starts with `grad` suggests the variable holds gradient data; `requires_grad_output_tuple` would be a clearer name. Reply: Done