[Dev] fix(layer_wise): tag MTP-stage word_embeddings as is_embedding_or_output_parameter#5180
Closed
Wohox wants to merge 1 commit into
Closed
Conversation
…put_parameter The is_embedding_or_output_parameter attribute (used by decoupled_lr, the Muon LayerWise optimizer, and FSDP) was only set on the pre_process stage's word embedding. On an MTP stage (mtp_process=True, pre_process=False) the duplicated word_embeddings copy was left untagged, so the LayerWise optimizer treated it as a Muon-managed parameter instead of routing it through the DistributedOptimizer, duplicating its optimizer state and inflating peak memory. Tag the MTP-stage embedding too, mirroring the pre_process path. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Ports #5034 (merged into
mainon 2026-05-28) todev. Thedevbranch — andthe latest
main→devnightly sync (#5029, cut 2026-05-27, one day before#5034 merged) — still has the pre-fix version.
is_embedding_or_output_parameter(used bydecoupled_lr, the MuonLayerWiseDistributedOptimizer, and FSDP) was only set on thepre_processstage's word embedding. On an MTP stage (
mtp_process=True,pre_process=False)the duplicated
word_embeddingscopy was left untagged, so the LayerWiseoptimizer treated it as a Muon-managed parameter instead of routing it through
the
DistributedOptimizer(the routing wired up by #4771). That duplicates theembedding's optimizer state and inflates peak memory.
This tags the MTP-stage embedding too, mirroring the
pre_processpath:Prerequisites (#4509 DDP-for-LayerWise, #4771 route-non-Muon-through-DistOpt) are
already present on
dev; this is the only missing piece.Test plan
no precision-aware-optimizer, no fp8-param-gather, force-balance. Without this
tag the run OOMs after iter 1 on a 184 GiB GB200 (peak was already ~178 GiB on
the branch that includes the fix); with it the run fits and trains, matching
the pre-sync feature branch that carried fix(layer_wise): tag MTP-stage word_embeddings as is_embedding_or_output_parameter #5034.
MTP-stage word-embedding parameter is affected.
🤖 Generated with Claude Code