[Dev] fix(layer_wise): tag MTP-stage word_embeddings as is_embedding_or_output_parameter#5180

Closed
Wohox wants to merge 1 commit into NVIDIA:dev from Wohox:wohox/tag-mtp-word-embeddings-is-embedding-param

@Wohox Wohox commented Jun 5, 2026

Contributor

Summary

Ports #5034 (merged into main on 2026-05-28) to dev. The dev branch, including
the latest main-to-dev nightly sync (#5029, cut 2026-05-27, one day before
#5034 merged), still has the pre-fix version.

is_embedding_or_output_parameter (used by decoupled_lr, the Muon
LayerWiseDistributedOptimizer, and FSDP) was only set on the pre_process
stage's word embedding. On an MTP stage (mtp_process=True, pre_process=False)
the duplicated word_embeddings copy was left untagged, so the LayerWise
optimizer treated it as a Muon-managed parameter instead of routing it through
the DistributedOptimizer (the routing wired up by #4771). That duplicates the
embedding's optimizer state and inflates peak memory.
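The routing described above can be illustrated with a small, torch-free sketch. It is not the LayerWise optimizer's actual code: `split_params`, `fake_param`, the parameter names other than `embedding.word_embeddings.weight`, and the rule "untagged 2-D weights go to Muon" are assumptions made for illustration only.

```python
from types import SimpleNamespace


def split_params(named_params):
    """Split parameters into (muon_names, dist_opt_names).

    Assumption for this sketch: untagged 2-D weights are Muon-managed;
    tagged embedding/output weights and 1-D params use the fallback path.
    """
    muon_names, dist_opt_names = [], []
    for name, param in named_params:
        tagged = getattr(param, "is_embedding_or_output_parameter", False)
        if param.ndim == 2 and not tagged:
            muon_names.append(name)
        else:
            dist_opt_names.append(name)
    return muon_names, dist_opt_names


def fake_param(ndim, tagged=False):
    """Stand-in for a parameter; carries only ndim and the optional tag."""
    param = SimpleNamespace(ndim=ndim)
    if tagged:
        param.is_embedding_or_output_parameter = True
    return param


def mtp_stage_params(embedding_tagged):
    """Hypothetical parameter list for an MTP stage."""
    return [
        ("embedding.word_embeddings.weight", fake_param(2, embedding_tagged)),
        ("mtp.layers.0.proj.weight", fake_param(2)),
        ("mtp.layers.0.norm.weight", fake_param(1)),
    ]


muon_before, dist_before = split_params(mtp_stage_params(embedding_tagged=False))
muon_after, dist_after = split_params(mtp_stage_params(embedding_tagged=True))
print("before fix, Muon-managed:", muon_before)
print("after fix, Muon-managed: ", muon_after)
```

In this toy model the untagged embedding lands in the Muon group, and tagging it moves it to the fallback group without touching any other parameter.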

This tags the MTP-stage embedding too, mirroring the pre_process path:

if (self.pre_process or getattr(self, 'mtp_process', False)) and hasattr(self, 'embedding'):
    self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True

Prerequisites (#4509 DDP-for-LayerWise, #4771 route-non-Muon-through-DistOpt) are
already present on dev; this is the only missing piece.
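To make the effect of the condition concrete, here is a minimal sketch that applies the same check to stand-in stage objects. `tag_embedding`, `make_stage`, and `is_tagged` are hypothetical helpers built on `SimpleNamespace`, not Megatron classes; only the condition itself is taken from the change above.

```python
from types import SimpleNamespace


def tag_embedding(stage):
    """Apply the tagging condition from this PR to a stand-in stage object."""
    if (stage.pre_process or getattr(stage, "mtp_process", False)) and hasattr(
        stage, "embedding"
    ):
        stage.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True


def make_stage(pre_process, mtp_process=False, with_embedding=True):
    """Build a hypothetical pipeline stage with an optional embedding."""
    stage = SimpleNamespace(pre_process=pre_process, mtp_process=mtp_process)
    if with_embedding:
        weight = SimpleNamespace()
        stage.embedding = SimpleNamespace(
            word_embeddings=SimpleNamespace(weight=weight)
        )
    return stage


def is_tagged(stage):
    weight = stage.embedding.word_embeddings.weight
    return getattr(weight, "is_embedding_or_output_parameter", False)


first_stage = make_stage(pre_process=True)
mtp_stage = make_stage(pre_process=False, mtp_process=True)
middle_stage = make_stage(pre_process=False, with_embedding=False)

for stage in (first_stage, mtp_stage, middle_stage):
    tag_embedding(stage)

print("first stage tagged:", is_tagged(first_stage))
print("MTP stage tagged:  ", is_tagged(mtp_stage))
```

The `hasattr` guard means a stage without an embedding is skipped rather than raising, and `getattr(..., False)` keeps the check safe on objects that never define `mtp_process`.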

Test plan

  • DeepSeek-V4 flash proxy (MTP enabled), GB200 4×4 (TP1/PP2/EP8), muon + mxfp8,
    no precision-aware-optimizer, no fp8-param-gather, force-balance. Without this
    tag the run OOMs after iter 1 on a 184 GiB GB200 (peak was already ~178 GiB on
    the branch that includes the fix); with it the run fits and trains, matching
    the pre-sync feature branch that carried #5034.
  • No numerical change: only the optimizer routing / memory placement of the
    MTP-stage word-embedding parameter is affected.

🤖 Generated with Claude Code

…put_parameter

The is_embedding_or_output_parameter attribute (used by decoupled_lr, the Muon
LayerWise optimizer, and FSDP) was only set on the pre_process stage's word
embedding. On an MTP stage (mtp_process=True, pre_process=False) the duplicated
word_embeddings copy was left untagged, so the LayerWise optimizer treated it as
a Muon-managed parameter instead of routing it through the DistributedOptimizer,
duplicating its optimizer state and inflating peak memory. Tag the MTP-stage
embedding too, mirroring the pre_process path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

copy-pr-bot Bot commented Jun 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Wohox Wohox closed this Jun 9, 2026