[megatron] fix: MTP loss deadlock when using context parallelism#5895
Open
xhx1022 wants to merge 1 commit intoverl-project:mainfrom
Open
[megatron] fix: MTP loss deadlock when using context parallelism#5895xhx1022 wants to merge 1 commit intoverl-project:mainfrom
xhx1022 wants to merge 1 commit intoverl-project:mainfrom
Conversation
The get_megatron_mtp_loss call was gated by is_mp_src_rank_with_outputs() which requires cp_rank==0, but the internal all_reduce uses DP+CP group that includes all CP ranks. CP rank>0 never entered the all_reduce, causing a deadlock. Fix by gating the call with is_pipeline_last_stage() so all CP ranks participate, and only writing metrics on the src rank. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor
There was a problem hiding this comment.
Code Review
This pull request updates the MTP loss handling in transformer_impl.py to ensure all ranks in the last pipeline stage participate in the reduction process, which is necessary for correct all-reduce operations across DP and CP groups. The feedback identifies a potential risk where a zero value for n_micro_batch could lead to a ZeroDivisionError or IndexError, and suggests adding a guard condition to prevent this.
Collaborator
|
@arvyanh review |
Contributor
looks correct, we had a similar fix locally that I forgot to cherry-pick to github # MTP loss is gathered across dp and cp
if (
self.model_config.mtp.enable
and mpu.get_tensor_model_parallel_rank() == 0
and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
):
# add mtp_losses
metrics = get_megatron_mtp_loss(n_micro_batch)
if "metrics" not in losses_reduced[0]:
losses_reduced[0]["metrics"] = {}
losses_reduced[0]["metrics"].update(metrics)The difference here ( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
get_megatron_mtp_lossinternally callsreduce_loss_in_tracker()which doesall_reduceover DP+CP group (all CP ranks), but the call was gated byis_mp_src_rank_with_outputs()which requirescp_rank==0, so CP rank>0 never participated in the all_reduce → deadlockis_mp_src_rank_with_outputs()tompu.is_pipeline_last_stage(ignore_virtual=True)so all CP/TP ranks participate in the all_reduce, while still only writing metrics on the src rank