
[megatron] fix: MTP loss deadlock when using context parallelism#5895

Open
xhx1022 wants to merge 1 commit into verl-project:main from xhx1022:mtp_cp

Conversation

Collaborator

@xhx1022 xhx1022 commented Apr 7, 2026

Summary

  • Fix deadlock when MTP is enabled with context parallelism (CP > 1)
  • get_megatron_mtp_loss internally calls reduce_loss_in_tracker() which does all_reduce over DP+CP group (all CP ranks), but the call was gated by is_mp_src_rank_with_outputs() which requires cp_rank==0, so CP rank>0 never participated in the all_reduce → deadlock
  • Changed outer gate from is_mp_src_rank_with_outputs() to mpu.is_pipeline_last_stage(ignore_virtual=True) so all CP/TP ranks participate in the all_reduce, while still only writing metrics on the src rank

The get_megatron_mtp_loss call was gated by is_mp_src_rank_with_outputs()
which requires cp_rank==0, but the internal all_reduce uses DP+CP group
that includes all CP ranks. CP rank>0 never entered the all_reduce,
causing a deadlock. Fix by gating the call with is_pipeline_last_stage()
so all CP ranks participate, and only writing metrics on the src rank.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
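
The participation mismatch described above can be reproduced without Megatron. The following is an illustrative-only sketch (plain Python, no torch; the gate functions and 2x2 rank layout are stand-ins, not Megatron's actual API) that tallies which (cp, tp) ranks on the last pipeline stage would reach the DP+CP all_reduce under each gate:

```python
# Illustrative reproduction of the participation mismatch: which
# (cp, tp) ranks on the last pipeline stage reach the DP+CP all_reduce
# under each gate. The gates mimic the two predicates from the PR.

def ranks_entering_allreduce(cp_world, tp_world, gate):
    """Set of (cp_rank, tp_rank) pairs that would call the collective."""
    return {(cp, tp)
            for cp in range(cp_world)
            for tp in range(tp_world)
            if gate(cp, tp)}

# Old gate: is_mp_src_rank_with_outputs() requires cp_rank == 0
old_gate = lambda cp, tp: cp == 0 and tp == 0
# New gate: is_pipeline_last_stage(ignore_virtual=True) holds on every
# CP/TP rank of the last stage
new_gate = lambda cp, tp: True

cp_world, tp_world = 2, 2
all_ranks = {(cp, tp) for cp in range(cp_world) for tp in range(tp_world)}

old = ranks_entering_allreduce(cp_world, tp_world, old_gate)
new = ranks_entering_allreduce(cp_world, tp_world, new_gate)

# The all_reduce spans the DP+CP group, so every CP rank must arrive.
assert old != all_ranks  # CP rank 1 never arrives: the collective hangs
assert new == all_ranks  # every rank participates: no deadlock
```

In a real run the mismatch does not raise an error; the ranks that did enter the collective simply block forever waiting for the ranks that never will, which is why the symptom is a hang rather than a crash.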
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the MTP loss handling in transformer_impl.py so that all ranks in the last pipeline stage participate in the reduction, which is required for the all-reduce over the DP and CP groups to complete. The review flags one residual risk: a zero value for n_micro_batch could lead to a ZeroDivisionError or an IndexError, and it suggests adding a guard condition to prevent this.
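
The guard the review suggests could look like the sketch below (the stub and helper names are hypothetical; the real get_megatron_mtp_loss lives in verl's Megatron utilities and may behave differently):

```python
# Hedged sketch of the review-suggested guard: bail out before the MTP
# metric update when no micro-batches ran. get_megatron_mtp_loss_stub
# divides by n_micro_batch to mimic the ZeroDivisionError risk, and
# indexing losses_reduced[0] mimics the IndexError risk.

def get_megatron_mtp_loss_stub(n_micro_batch, total_loss=0.0):
    return {"mtp_loss": total_loss / n_micro_batch}  # raises if n_micro_batch == 0

def update_mtp_metrics(losses_reduced, n_micro_batch):
    if n_micro_batch <= 0 or not losses_reduced:  # suggested guard
        return losses_reduced
    metrics = get_megatron_mtp_loss_stub(n_micro_batch)
    losses_reduced[0].setdefault("metrics", {}).update(metrics)
    return losses_reduced

assert update_mtp_metrics([], 0) == []  # guard path: no crash
out = update_mtp_metrics([{}], 2)
assert out[0]["metrics"]["mtp_loss"] == 0.0
```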

@xhx1022 xhx1022 requested a review from PeterSH6 April 7, 2026 08:17
@ArronHZG ArronHZG self-requested a review April 7, 2026 09:13
@ArronHZG
Collaborator

ArronHZG commented Apr 7, 2026

@arvyanh review

@arvyanh
Contributor

arvyanh commented Apr 7, 2026


looks correct, we had a similar fix locally that I forgot to cherry-pick to github

```python
# MTP loss is gathered across dp and cp
if (
    self.model_config.mtp.enable
    and mpu.get_tensor_model_parallel_rank() == 0
    and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
):
    # add mtp_losses
    metrics = get_megatron_mtp_loss(n_micro_batch)
    if "metrics" not in losses_reduced[0]:
        losses_reduced[0]["metrics"] = {}
    losses_reduced[0]["metrics"].update(metrics)
```

The difference here is that with `if self.is_mp_src_rank_with_outputs():`, only CP rank 0 reports the loss.
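
Putting the two pieces together (a broad outer gate so every rank reaches the collective, and a narrow inner gate so only the source rank reports), the control flow this PR describes can be sketched as follows. FakeMPU and the collective stub are stand-ins for illustration, not Megatron's real modules:

```python
# Hedged sketch of the corrected control flow: every rank on the last
# pipeline stage calls the collective, but only the model-parallel
# source rank records metrics.

class FakeMPU:
    def __init__(self, cp_rank, tp_rank):
        self.cp_rank, self.tp_rank = cp_rank, tp_rank

    def is_pipeline_last_stage(self, ignore_virtual=True):
        return True  # single-stage toy setup: every rank is on the last stage

entered, metrics_written = [], []

def mtp_loss_collective_stub(rank):
    entered.append(rank)  # stand-in for the all_reduce over the DP+CP group
    return {"mtp_loss": 0.0}

for cp in range(2):
    for tp in range(2):
        mpu = FakeMPU(cp, tp)
        if mpu.is_pipeline_last_stage(ignore_virtual=True):  # broad outer gate
            m = mtp_loss_collective_stub((cp, tp))  # all ranks join the collective
            if cp == 0 and tp == 0:  # narrow inner gate: only src rank writes
                metrics_written.append(m)

assert len(entered) == 4  # all four CP/TP ranks reached the collective
assert metrics_written == [{"mtp_loss": 0.0}]  # exactly one rank reported
```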
