[Bugfix] Restore moe_forward output shape invariant on TRTLLM MXFP4 path #41646

Open

stecasta wants to merge 2 commits into vllm-project:main from stecasta:fix-gpt-oss-moe-stride
Conversation

stecasta (Contributor) commented May 4, 2026

Purpose

Fixes #41645. gpt-oss-{20b,120b} crash under torch.compile with tensor_parallel_size > 1 on Blackwell because of a fake/real shape mismatch in vllm.moe_forward:

  • The TRT-LLM MXFP4 experts kernel (TrtLlmMxfp4Experts{Monolithic,Modular}.apply) writes output at moe_config.hidden_dim_unpadded (gpt-oss: 2880).
  • The fake _moe_forward_fake returned torch.empty_like(hidden_states) whose last dim is the padded hidden_dim (gpt-oss: 3072 after the kernel-alignment pad in _maybe_pad_hidden_states).

When inductor's assert_size_stride checks the runtime tensor against the traced fake, it fires:

```
AssertionError: expected size 256==256, stride 2880==3072 at dim=0;
                expected size 2880==3072, stride 1==1 at dim=1
Error in op: torch.ops.vllm.moe_forward.default
```
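
For reference, the padded width is just the logical hidden dim rounded up to the kernel's alignment. A minimal sketch reproducing the widths discussed in this PR (the alignment constants are inferred from the numbers, not read from the kernels):

```python
def round_up(x: int, multiple: int) -> int:
    # How kernel-alignment padding widens a dim: round up to the next multiple.
    return ((x + multiple - 1) // multiple) * multiple

# gpt-oss logical hidden dim is 2880; these alignments reproduce the
# widths that appear in this PR (illustrative inference only):
assert round_up(2880, 256) == 3072  # TRT-LLM MXFP4 padded width
assert round_up(2880, 128) == 2944  # CUTLASS 128-aligned width
```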

The divergence was introduced by #40960 (added kernel-alignment padding without updating the fake).

Approach

vllm/model_executor/layers/fused_moe/runner/moe_runner.py

Add hidden_dim_unpadded: int to the _moe_forward / _moe_forward_shared signatures (and their matching fakes). The caller in MoERunner.forward computes the value via _trtllm_mxfp4_unpadded_dim, which returns moe_config.hidden_dim_unpadded only when the active backend is a TrtLlmMxfp4ExpertsBase, else 0. The fake allocates the narrow shape when the value is positive and otherwise falls through to empty_like(hidden_states).

Computing the discriminator caller-side, rather than peeking at layer state inside the fake, is necessary: doing the isinstance check inside _moe_forward_fake specializes the fake per layer_name and breaks torch.compile subgraph dedup (tests/compile/h100/test_startup.py::test_moe_startup is the canary that catches this). moe_config.hidden_dim_unpadded alone is also insufficient: it encodes the model's logical hidden dim, not whether the active kernel narrows; Cutlass MXFP4 MXFP8 writes the full padded width and would be misclassified. A sketch of the resulting contract follows.
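
A minimal sketch of that contract, using the names from this description (the real signatures carry more arguments, elided here, and the stub class stands in for the real backend base class):

```python
import torch

class TrtLlmMxfp4ExpertsBase:  # stand-in for the real experts base class
    pass

def _trtllm_mxfp4_unpadded_dim(moe_config, experts) -> int:
    # Caller-side discriminator: positive only when the active backend
    # writes the narrow (unpadded) width; 0 means "full padded width".
    if isinstance(experts, TrtLlmMxfp4ExpertsBase):
        return moe_config.hidden_dim_unpadded
    return 0

def _moe_forward_fake(hidden_states: torch.Tensor,
                      hidden_dim_unpadded: int) -> torch.Tensor:
    if hidden_dim_unpadded > 0:
        # Match the TRT-LLM MXFP4 kernel's narrow allocation (gpt-oss: 2880).
        return hidden_states.new_empty(
            *hidden_states.shape[:-1], hidden_dim_unpadded
        )
    # Every other backend writes the full (possibly padded) width.
    return torch.empty_like(hidden_states)
```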

gemini-code-assist (Bot) left a comment

Code Review

This pull request updates the trtllm_mxfp4_moe.py file by replacing self.hidden_dim_unpadded with self.hidden_dim for the output tensor allocation in the apply method and updating the workspace_shapes calculation to use K. I have no feedback to provide as there were no review comments to evaluate.

mergify (Bot) added the nvidia and bug labels May 4, 2026
mgoin added the ready and gpt-oss labels May 4, 2026
stecasta (Contributor, Author) commented May 4, 2026

Seems like there is an infra issue behind the failing tests, @mgoin. Should we retry them?

simon-mo (Collaborator) commented May 4, 2026

I retried failed tests.

github-project-automation (Bot) moved this to Ready in NVIDIA May 4, 2026
github-project-automation (Bot) moved this from To Triage to Ready in gpt-oss Issues & Enhancements May 4, 2026
mgoin (Member) left a comment

Thanks!

ywang96 (Member) commented May 4, 2026

Is this test failure related?

```
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]   File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]     raise self._exception
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 94, in _wait_for_response
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]     response = self.aggregate(self.get_response())
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]                               ^^^^^^^^^^^^^^^^^^^
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 390, in get_response
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]     raise RuntimeError(
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136] RuntimeError: Worker failed with error 'Workspace validation failed:
(EngineCore pid=13238) ERROR 05-04 18:58:58 [core.py:1136]   - token_num (8192) * hidden_dim (3072) exceeds workspace max_token_num (8192) * hidden_dim (2880). This may cause Illegal Memory Access.', please check the stack trace above for the root cause
```
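
The guard that fires reduces to a capacity comparison: the kernel wants 8192 × 3072 output elements while the workspace was sized for 8192 × 2880. A hypothetical reconstruction of the check, for illustration only (vLLM's actual validation carries more state):

```python
def check_workspace(token_num: int, hidden_dim: int,
                    max_token_num: int, ws_hidden_dim: int) -> None:
    # Refuse to launch if the requested output would overrun the
    # preallocated workspace buffer.
    if token_num * hidden_dim > max_token_num * ws_hidden_dim:
        raise RuntimeError(
            f"Workspace validation failed: token_num ({token_num}) * "
            f"hidden_dim ({hidden_dim}) exceeds workspace max_token_num "
            f"({max_token_num}) * hidden_dim ({ws_hidden_dim})."
        )
```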

zyongye (Member) commented May 4, 2026

> Is this test failure related? […]

I think so.

stecasta (Contributor, Author) commented May 5, 2026

The B200 tests hit the buggy agent dgxb200-01-4. Can someone re-trigger them?

stecasta (Contributor, Author) commented May 5, 2026

I ended up with a different approach: an explicit hidden_dim_unpadded: int argument on the four _moe_forward custom-op signatures. The fake allocates at hidden_dim_unpadded; the real op truncates with out[..., :hidden_dim_unpadded].contiguous() only when the backend writes wider. One contract handles all backend widths (TRTLLM unpadded 2880, TRTLLM widened 3072, CUTLASS 128B-aligned 2944), and it is a no-op for any model where hidden_dim_unpadded == hidden_dim. A sketch of the truncation follows.
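
Roughly, the real-op side of that contract looked like this (the helper name is hypothetical; in the actual change the slice lived inline in the op body):

```python
import torch

def _narrow_to_logical_width(out: torch.Tensor,
                             hidden_dim_unpadded: int) -> torch.Tensor:
    # Truncate a backend's possibly-padded output back to the model's
    # logical hidden dim; no-op when the widths already match.
    if out.shape[-1] == hidden_dim_unpadded:
        return out
    return out[..., :hidden_dim_unpadded].contiguous()
```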

zyongye (Member) commented May 5, 2026

> I ended up with a different approach: explicit hidden_dim_unpadded: int argument on the four _moe_forward custom-op signatures. […]

I think we can get this parameter from:

```python
layer = get_layer_from_name(_resolve_layer_name(layer_name))
hidden_dim_unpadded = layer.config.hidden_dim_unpadded
```

Also I feel like the .contiguous() at the very end is unnecessary.

stecasta (Contributor, Author) commented May 5, 2026

> I think we can get this parameter from […]

I think we would also need to do something like this:

```python
register_opaque_type(LayerName, hoist=True, members={"value": MemberType.USE_REAL})
```

I chose to add the explicit parameter instead because I wasn't sure whether this would affect other consumers, but I'm happy to switch to this approach and then read the hidden dim like this:

```python
hidden_dim_unpadded = (
    layer.moe_config.hidden_dim_unpadded or layer.moe_config.hidden_dim
)
```

stecasta (Contributor, Author) commented May 5, 2026

I'm actually thinking it's better to restrict the fix to the TRT-LLM mxfp4 path. Let me work on that.

stecasta force-pushed the fix-gpt-oss-moe-stride branch from 54306df to 0927ec7 on May 6, 2026
…t-oss MXFP4 + torch.compile

Fixes vllm-project#41645. The TRT-LLM MXFP4 experts kernel writes output at
moe_config.hidden_dim_unpadded while _moe_forward_fake returned
empty_like(hidden_states) at the kernel-aligned padded width, tripping
inductor's assert_size_stride. Plumb the unpadded dim through the custom
op signature so the fake matches the real op's allocation, gated to
TRT-LLM MXFP4 only (other backends, including Cutlass MXFP4 MXFP8, write
the full padded width).

Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
stecasta force-pushed the fix-gpt-oss-moe-stride branch from 0927ec7 to 3180762 on May 6, 2026
stecasta (Contributor, Author) commented May 6, 2026

@zyongye updated:

  • Removed the .contiguous() by restricting the fix to the TRT-LLM MXFP4 path. Other backends keep their existing empty_like shape.
  • Couldn't avoid the new hidden_dim_unpadded: int on the op signature. Alternatives I tried break either subgraph dedup (caught by test_moe_startup: num_compiled_artifacts_saved=33 instead of 3 on Phi-MoE) or attention paths (opaque-type member registration via MemberType.USE_REAL). Reasoning in the PR description. Open to suggestions on a cleaner shape.

CI is now green.

zyongye enabled auto-merge (squash) May 6, 2026
stecasta (Contributor, Author) commented May 7, 2026

The failing tests are unrelated; see #41887.


Labels

bug, gpt-oss, nvidia, ready

Projects

Status: Ready

Development

Successfully merging this pull request may close these issues.

[Bug]: gpt-oss MoE moe_forward fake-kernel shape mismatch breaks torch.compile + TP > 1 on Blackwell

5 participants