@@ -205,7 +205,7 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}
-
+        self.grad_recv_indices: dict[int, list] = {}
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
@@ -481,7 +481,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
                     else dtensor_to_local(
                         out,
                         out.process_mesh,
-                        self.grads_meta[idx].placements,
+                        out.placements,
                     )
                 ),
                 peer_global_rank,
@@ -620,9 +620,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
     def backward_maybe_with_nosync(
         self, backward_type, bwd_kwargs: dict, last_backward=False
    ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
-        """
-        When PP is mixed with DP, some behaviors at the start of the backward pass of the last microbatch of each batch may differ, so extra care may be needed here.
-        """
 
        def perform_backward(
            backward_type,
@@ -711,9 +708,19 @@ def forward_one_chunk(
         flat_args = _flatten_args(input_args)
         flat_kwargs = _flatten_args(composite_kwargs)
         flatten_input_tensors = flat_args + flat_kwargs
+        grad_required_output_tuple = tuple(
+            out
+            for out in output_tuple
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
+        grad_required_flatten_input_tensors = [
+            inp
+            for inp in flatten_input_tensors
+            if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
+        ]
         self.fwd_cache[fwd_chunk_id] = (
-            output_tuple,  # stage_output
-            flatten_input_tensors,  # input_values
+            grad_required_output_tuple,  # stage_output
+            grad_required_flatten_input_tensors,  # input_values
         )
 
         logger.debug(
@@ -1005,8 +1012,16 @@ def requires_grad(x):
         flatten_input_tensors = [
             x for x in flatten_input_tensors if not x.stop_gradient
         ]
+        grad_required_outputs = _normalize_model_output_as_tuple(
+            outputs
+        )
+        grad_required_outputs = tuple(
+            out
+            for out in grad_required_outputs
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
         self.fwd_cache[0] = (
-            _normalize_model_output_as_tuple(outputs),  # stage_output
+            grad_required_outputs,  # stage_output
             flatten_input_tensors,  # input_values
         )
 
@@ -1100,13 +1115,16 @@ def _prepare_forward_infra(
         # Send info during forward for each activation
         # only need the rank that is being sent to
         self.act_send_info: dict[int, list] = {}
-
-        for idx in range(len(self.get_outputs_meta())):
+        outputs_meta = self.get_outputs_meta()
+        for idx in range(len(outputs_meta)):
             # We assume we always send to stage + 1
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
+                if not outputs_meta[idx].stop_gradient:
+                    self.grad_recv_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
+                self.grad_recv_indices[idx] = []
 
         return outputs
 
@@ -1195,9 +1213,7 @@ def _prepare_backward_infra(
         self._shape_inference_bwd()
         for mb_index in range(num_microbatches):
             # `grad_recv_info` is a mirror of `act_send_info`
-            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
-                self.act_send_info
-            )
+            self.grad_recv_info[mb_index] = self._create_grad_recv_info()
 
         # the last stage does not need recv grads from other rank
         if not self.is_last:
@@ -1211,7 +1227,6 @@ def _prepare_backward_infra(
 
     def _create_grad_recv_info(
         self,
-        act_send_info: dict,
     ) -> tuple[_RecvInfo, ...]:
         grad_recv_info: tuple[_RecvInfo, ...] = ()
         if not self.is_last:
@@ -1224,7 +1239,7 @@ def _create_grad_recv_info(
                     dst_list[0],
                     _make_tensor_from_meta(self.grads_meta[idx]),
                 )
-                for idx, dst_list in act_send_info.items()
+                for idx, dst_list in self.grad_recv_indices.items()
             ]
         )
         return grad_recv_info