@@ -204,7 +204,9 @@ def __init__(
         # Forward infra
         self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
         self.act_send_info: dict[int, list] = {}
-
+        self._need_grad_indices: dict[int, list] = (
+            {}
+        )  # record the indices of outputs that need to receive grad from the next stage.
         # Backward infra will created lazily
         self.grad_recv_info: dict = {}
         self.grad_send_info: list | None = None
@@ -480,7 +482,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
                             else dtensor_to_local(
                                 out,
                                 out.process_mesh,
-                                self.grads_meta[idx].placements,
+                                out.placements,
                             )
                         ),
                         peer_global_rank,
@@ -619,9 +621,6 @@ def forward_maybe_with_nosync(self, *args, **kwargs):
     def backward_maybe_with_nosync(
         self, backward_type, bwd_kwargs: dict, last_backward=False
     ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]:
-        """
-        When PP and DP are used together, some behaviors at the start of the backward pass of the last microbatch of each batch may differ, so extra care is needed here.
-        """

         def perform_backward(
             backward_type,
@@ -710,9 +709,19 @@ def forward_one_chunk(
         flat_args = _flatten_args(input_args)
         flat_kwargs = _flatten_args(composite_kwargs)
         flatten_input_tensors = flat_args + flat_kwargs
+        requires_grad_output_tuple = tuple(
+            out
+            for out in output_tuple
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
+        flatten_requires_grad_input_tensors = [
+            inp
+            for inp in flatten_input_tensors
+            if isinstance(inp, paddle.Tensor) and not inp.stop_gradient
+        ]
         self.fwd_cache[fwd_chunk_id] = (
-            output_tuple,  # stage_output
-            flatten_input_tensors,  # input_values
+            requires_grad_output_tuple,  # stage_output
+            flatten_requires_grad_input_tensors,  # input_values
         )

         logger.debug(
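The hunk above makes `fwd_cache` hold only tensors that will actually take part in backward (`stop_gradient` is `False`) instead of every output and flattened input. Below is a minimal, self-contained sketch of that filtering pattern; the helper name `filter_requires_grad` is illustrative and not part of the stage code.

```python
import paddle

def filter_requires_grad(values):
    # Keep only paddle.Tensor objects that will receive gradients;
    # non-tensors and frozen tensors (stop_gradient=True) are dropped.
    return [
        v for v in values
        if isinstance(v, paddle.Tensor) and not v.stop_gradient
    ]

x = paddle.randn([2, 3])
x.stop_gradient = False   # trainable activation
y = paddle.randn([2, 3])  # stop_gradient defaults to True

print(len(filter_requires_grad([x, y, "not a tensor"])))  # -> 1
```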
@@ -1169,8 +1178,16 @@ def requires_grad(x):
         flatten_input_tensors = [
             x for x in flatten_input_tensors if not x.stop_gradient
         ]
+        grad_required_outputs = _normalize_model_output_as_tuple(
+            outputs
+        )
+        grad_required_outputs = tuple(
+            out
+            for out in grad_required_outputs
+            if isinstance(out, paddle.Tensor) and not out.stop_gradient
+        )
         self.fwd_cache[0] = (
-            _normalize_model_output_as_tuple(outputs),  # stage_output
+            grad_required_outputs,  # stage_output
             flatten_input_tensors,  # input_values
         )

@@ -1264,13 +1281,16 @@ def _prepare_forward_infra(
         # Send info during forward for each activation
         # only need the rank that is being sent to
         self.act_send_info: dict[int, list] = {}
-
-        for idx in range(len(self.get_outputs_meta())):
+        outputs_meta = self.get_outputs_meta()
+        for idx in range(len(outputs_meta)):
             # We assume we always send to stage + 1
             if not self.is_last:
                 self.act_send_info[idx] = [self.stage_index + 1]
+                if not outputs_meta[idx].stop_gradient:
+                    self._need_grad_indices[idx] = [self.stage_index + 1]
             else:
                 self.act_send_info[idx] = []
+                self._need_grad_indices[idx] = []

         return outputs

@@ -1359,9 +1379,7 @@ def _prepare_backward_infra(
         self._shape_inference_bwd()
         for mb_index in range(num_microbatches):
             # `grad_recv_info` is a mirror of `act_send_info`
-            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
-                self.act_send_info
-            )
+            self.grad_recv_info[mb_index] = self._create_grad_recv_info()

         # the last stage does not need recv grads from other rank
         if not self.is_last:
@@ -1375,7 +1393,6 @@ def _prepare_backward_infra(

     def _create_grad_recv_info(
         self,
-        act_send_info: dict,
     ) -> tuple[_RecvInfo, ...]:
         grad_recv_info: tuple[_RecvInfo, ...] = ()
         if not self.is_last:
@@ -1388,7 +1405,7 @@ def _create_grad_recv_info(
                         dst_list[0],
                         _make_tensor_from_meta(self.grads_meta[idx]),
                     )
-                    for idx, dst_list in act_send_info.items()
+                    for idx, dst_list in self._need_grad_indices.items()
                 ]
             )
         return grad_recv_info
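With `_create_grad_recv_info` reading `self._need_grad_indices`, a stage allocates gradient-receive slots only for the output indices whose activations require a gradient, rather than mirroring every entry of `act_send_info`. A standalone sketch of that difference, using plain dicts and tuples; `build_grad_recv_slots` and the sample values are illustrative, not the stage's API.

```python
# Destination stage per output index, as act_send_info would record it.
act_send_info = {0: [3], 1: [3], 2: [3]}
# Suppose only outputs 0 and 2 have stop_gradient=False, so only they get a grad back.
need_grad_indices = {0: [3], 2: [3]}

def build_grad_recv_slots(indices):
    # One (output index, source stage) pair per recorded entry, mirroring
    # the one-_RecvInfo-per-index construction above.
    return tuple((idx, dst_list[0]) for idx, dst_list in indices.items())

print(build_grad_recv_slots(act_send_info))      # ((0, 3), (1, 3), (2, 3)) -- before
print(build_grad_recv_slots(need_grad_indices))  # ((0, 3), (2, 3))         -- after
```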