In reduce_aux_losses_tracker_across_ranks, when the pipeline-parallel size equals the number of model blocks, i.e. each pipeline stage holds exactly one layer (e.g. num_layers=4, pp=4), the all_reduce inside the function raises an error:

    def reduce_aux_losses_tracker_across_ranks(track_names: Optional[List[str]] = None):
        ...
        torch.distributed.all_reduce(
            values, group=parallel_state.get_pipeline_model_parallel_group()
        )
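
For context on why a per-stage mismatch would break here (my reading of the symptom, not verified against the failing run): torch.distributed.all_reduce is a collective, so every rank in the pipeline-model-parallel group must call it with a tensor of the same shape and dtype. Below is a minimal, self-contained sketch of that constraint; it is not Megatron code, and the one-slot-per-layer layout is an assumption.

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    NUM_LAYERS = 4  # one tracked loss slot per transformer layer (assumed layout)

    def stage(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29512"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Every rank allocates the SAME shape: a slot for every global layer,
        # filling only the slot it owns. The collective then succeeds.
        values = torch.zeros(NUM_LAYERS)
        values[rank] = float(rank + 1)  # pretend this stage owns layer `rank`
        dist.all_reduce(values)         # sums per-layer values across all stages
        print(f"rank {rank}: {values.tolist()}")

        # If instead one rank called all_reduce with torch.zeros(1) while the
        # others used torch.zeros(4), the collective would error or hang --
        # which matches the symptom reported above.
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(stage, args=(4,), nprocs=4)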
With num_layers=4 and pp=2 (two layers per stage), the same call runs correctly.
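
If the failure really is a shape/participation mismatch between stages, one possible workaround (a sketch under that assumption; reduce_layerwise_values, layer_offset, and global_num_layers are hypothetical names, not the upstream API) is to have every pipeline rank scatter its locally tracked values into a buffer sized by the global layer count, so the collective always sees identical shapes:

    import torch
    import torch.distributed as dist

    def reduce_layerwise_values(local_values: torch.Tensor,
                                layer_offset: int,
                                global_num_layers: int,
                                group=None) -> torch.Tensor:
        # Scatter the locally tracked per-layer values into a globally
        # sized buffer; stages that own fewer layers just leave zeros.
        full = torch.zeros(global_num_layers,
                           dtype=local_values.dtype, device=local_values.device)
        full[layer_offset:layer_offset + local_values.numel()] = local_values
        # Every rank in the group now enters the collective with the
        # same shape, regardless of how layers are split across stages.
        dist.all_reduce(full, group=group)
        return full

With num_layers=4 and pp=4, every stage would pass a length-1 local_values plus its layer offset, and the all_reduce always operates on a length-4 tensor.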