Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gpt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _get_transformer_layer_spec(use_te, config):
use_kitchen_attention=config.use_kitchen_attention,
kitchen_attention_backend=config.kitchen_attention_backend,
mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
dense_grouped_gemm=config.dense_grouped_gemm,
)
elif config.transformer_impl == "inference_optimized":
return get_gpt_layer_with_inference_spec(
Expand Down
185 changes: 185 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,8 +2557,193 @@ def as_mlp_submodule(
ffn_hidden_size=ffn_hidden_size,
)

class TEFusedDenseMLP(TEFusedMLP):
Comment thread
sraman-rgb marked this conversation as resolved.
Outdated
"""Dense MLP using GroupedLinear(num_groups=1) to trigger
ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fusion on SM100+ with MXFP8 recipe.

Subclass of TEFusedMLP -> does not modify TEFusedMLP or TEGroupedMLP.
The fused kernel fires automatically via the TE op fuser when it detects
the GroupedLinear -> ScaledSwiGLU -> GroupedLinear pattern with MXFP8 recipe.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._norm_seq: Optional[Tuple[te.pytorch.ops.Sequential]] = None
if not is_te_min_version("2.14.0"):
raise RuntimeError(
f"{self.__class__.__name__} requires Transformer Engine >= 2.14.0 "
"(needs pytorch.ops.GroupedLinear and pytorch.ops.ScaledSwiGLU)"
)
if self.config.add_bias_linear:
raise ValueError(
f"{self.__class__.__name__} does not support add_bias_linear=True; "
"the CuTeGEMM fused kernel requires bias-free linear layers."
)
if self.config.activation_func != F.silu or not self.config.gated_linear_unit:
raise ValueError(
f"{self.__class__.__name__} requires SwiGLU activation "
"(activation_func=F.silu, gated_linear_unit=True) "
"for the CuTeGEMM fused kernel, but got "
f"activation_func={self.config.activation_func}, "
f"gated_linear_unit={self.config.gated_linear_unit}."
)

def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
"""Construct fused module with GroupedLinear(num_groups=1) + ScaledSwiGLU."""

fused_impl = te.pytorch.ops.Sequential()

# Tensor parallelism configuration
tp_world_size = get_tensor_model_parallel_world_size()
tp_group = None
if tp_world_size > 1:
tp_group = get_tensor_model_parallel_group()
Comment thread
sraman-rgb marked this conversation as resolved.
Outdated

# RNG state
rng_state_tracker_function = None
if get_cuda_rng_tracker().is_initialized():
rng_state_tracker_function = get_cuda_rng_tracker

# Check submodule types (same as TEFusedMLP)
if not isinstance(self.linear_fc1, te.pytorch.LayerNormLinear):
raise ValueError(
f"{self.__class__.__name__} expects FC1 to be "
"Transformer Engine LayerNormLinear, but found "
f"{self.linear_fc1.__class__.__name__}."
)
if not isinstance(self.linear_fc2, te.pytorch.Linear):
raise ValueError(
f"{self.__class__.__name__} expects FC2 to be "
"Transformer Engine Linear, but found "
f"{self.linear_fc2.__class__.__name__}."
)

# Norm op (same as TEFusedMLP)
norm_type = self.linear_fc1.normalization
norm_shape = self.linear_fc1.weight.size(1)
kwargs = {
"eps": self.linear_fc1.eps,
"device": "meta",
"dtype": self.linear_fc1.layer_norm_weight.dtype,
"zero_centered_gamma": self.linear_fc1.zero_centered_gamma,
}
op = None
if norm_type == "LayerNorm":
op = te.pytorch.ops.LayerNorm(norm_shape, **kwargs)
op.weight = self.linear_fc1.layer_norm_weight
op.bias = self.linear_fc1.layer_norm_bias
elif norm_type == "RMSNorm":
op = te.pytorch.ops.RMSNorm(norm_shape, **kwargs)
op.weight = self.linear_fc1.layer_norm_weight
else:
raise ValueError(f"Unsupported normalization ({norm_type})")
# Store norm in a separate Sequential applied OUTSIDE the MXFP8 autocast
# in forward(). Running norm inside MXFP8 context corrupts the saved rstd
# used in RMSNorm backward, causing gradient amplification up to 10^6.
Comment thread
sraman-rgb marked this conversation as resolved.
# Wrapped in tuple to avoid nn.Module submodule registration (which would
# duplicate the shared norm weight in state_dict/parameters).
norm_seq = te.pytorch.ops.Sequential()
norm_seq.append(op)
self._norm_seq = (norm_seq,)

# GLU interleave size must match ScaledSwiGLU and the CuTe kernel.
_GLU_INTERLEAVE_SIZE = 32
Comment thread
sraman-rgb marked this conversation as resolved.

# FC1: GroupedLinear(num_groups=1) instead of BasicLinear
weight = self.linear_fc1.weight
op = te.pytorch.ops.GroupedLinear(
num_groups=1,
in_features=weight.size(1),
out_features=weight.size(0) * tp_world_size,
device="meta",
dtype=weight.dtype,
bias=False,
rng_state_tracker_function=rng_state_tracker_function,
accumulate_into_main_grad=self.linear_fc1.fuse_wgrad_accumulation,
)
op.weight0 = weight
op._glu_interleave_size = _GLU_INTERLEAVE_SIZE # signals fuser_forward to interleave
fused_impl.append(op)

# ScaledSwiGLU with glu_interleave_size=32
# Required by ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
fused_impl.append(te.pytorch.ops.ScaledSwiGLU(glu_interleave_size=32))

# FC2: GroupedLinear(num_groups=1) instead of BasicLinear
weight = self.linear_fc2.weight
op = te.pytorch.ops.GroupedLinear(
num_groups=1,
in_features=weight.size(1),
out_features=weight.size(0),
device="meta",
dtype=weight.dtype,
bias=False,
rng_state_tracker_function=rng_state_tracker_function,
accumulate_into_main_grad=self.linear_fc2.fuse_wgrad_accumulation,
)
op.weight0 = weight
# FC2 has no SwiGLU — MXFP8 quantization done on-the-fly in fuser_forward.
# No _mxfp8_weight0 pre-computation to avoid ~28 GB persistent FP8 tensors.
fused_impl.append(op)

if tp_world_size > 1:
if self.linear_fc2.sequence_parallel:
fused_impl.append(te.pytorch.ops.ReduceScatter(tp_group))
else:
fused_impl.append(te.pytorch.ops.AllReduce(tp_group))

self._register_hooks_on_fused_impl(fused_impl)
return fused_impl

def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Optional[Tensor]]:
"""Forward pass using GroupedLinear(num_groups=1) + ScaledSwiGLU."""

orig_shape = hidden_states.shape
hidden_size = hidden_states.size(-1)
hidden_states_2d = hidden_states.view(-1, hidden_size)
total_tokens = hidden_states_2d.size(0)

tokens_per_expert = torch.full(
(1,), total_tokens, dtype=torch.long, device=hidden_states.device
)
scales = torch.ones(
total_tokens, device=hidden_states.device, dtype=hidden_states.dtype
)

# Build fused impl and cache recipe lazily on first forward pass.
# Both are created once and reused — avoids object creation every call.
if not hasattr(self, '_recipe'):
if os.getenv("FP4_RECIPE", "") == "nvfp4":
self._recipe = te.common.recipe.NVFP4BlockScaling()
else:
self._recipe = te.common.recipe.MXFP8BlockScaling()
recipe = self._recipe

if self._fused_impl is None:
with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
Comment thread
sraman-rgb marked this conversation as resolved.
Outdated
self._fused_impl = (self._make_fused_impl(),)

# Apply norm in BF16 OUTSIDE the MXFP8 autocast to preserve the rstd
# tensor used by RMSNorm backward (running it inside causes up to 10^6
# gradient amplification, and causes convergence issues).
normed = self._norm_seq[0](hidden_states_2d)

with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
Comment thread
sraman-rgb marked this conversation as resolved.
Outdated
out = self._fused_impl[0](normed, tokens_per_expert, scales, tokens_per_expert)

out = out.view(*orig_shape[:-1], out.size(-1))

bias = None
if self.linear_fc2.te_return_bias:
bias = self.linear_fc2.bias
if isinstance(bias, torch.Tensor) and bias.numel() == 0:
bias = None

return out, bias

else:
TEFusedMLP = None # type: ignore[assignment, misc]
TEFusedDenseMLP = None # type: ignore[assignment, misc]


class TEDelayedScaling(te.common.recipe.DelayedScaling):
Expand Down
14 changes: 11 additions & 3 deletions megatron/core/models/gpt/gpt_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
from megatron.core.utils import is_te_min_version

if HAVE_TE:
from megatron.core.extensions.transformer_engine import TEFusedMLP, TENorm
from megatron.core.extensions.transformer_engine import TEFusedDenseMLP, TEFusedMLP, TENorm
from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider
else:
TEFusedMLP, TENorm, TESpecProvider = None, None, None
TEFusedDenseMLP, TEFusedMLP, TENorm, TESpecProvider = None, None, None, None

try:
from megatron.core.extensions.kitchen import HAVE_KITCHEN, KitchenSpecProvider
Expand Down Expand Up @@ -185,6 +185,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
use_kitchen_attention: bool = False,
kitchen_attention_backend: str = "sdpa",
mla_down_proj_fusion: bool = False,
dense_grouped_gemm: bool = False,
) -> TransformerLayerSubmodules:
"""Use these submodules to use lower-level Transformer Engine modules (required for fp8
training).
Expand Down Expand Up @@ -233,6 +234,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
moe_grouped_gemm=moe_grouped_gemm,
use_te_op_fuser=use_te_op_fuser,
use_te_activation_func=use_te_activation_func,
dense_grouped_gemm=dense_grouped_gemm,
)

if multi_latent_attention:
Expand Down Expand Up @@ -518,6 +520,7 @@ def get_mlp_module_spec_for_backend(
moe_grouped_gemm: Optional[bool] = False,
use_te_op_fuser: Optional[bool] = False,
use_te_activation_func: bool = False,
dense_grouped_gemm: bool = False,
) -> MlpBuilder:
"""Helper function to get module spec for MLP/MoE"""

Expand All @@ -526,7 +529,12 @@ def get_mlp_module_spec_for_backend(

if num_experts is None:
# Dense MLP w/ or w/o TE modules.
module = not_none(TEFusedMLP).as_mlp_submodule if use_te_op_fuser else MLP.as_mlp_submodule
if dense_grouped_gemm and use_te_op_fuser:
module = not_none(TEFusedDenseMLP).as_mlp_submodule
elif use_te_op_fuser:
module = not_none(TEFusedMLP).as_mlp_submodule
else:
module = MLP.as_mlp_submodule
if backend.fuse_layernorm_and_linear():
linear_fc1 = backend.column_parallel_layer_norm_linear()
assert linear_fc1 is not None
Expand Down
6 changes: 6 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,12 @@ class TransformerConfig(ModelParallelConfig):
If negative, generates bias once per layer and reuses it (abs value is std).
This is an experimental feature for benchmarking purposes."""

dense_grouped_gemm: bool = False
Comment thread
sraman-rgb marked this conversation as resolved.
Outdated
"""Use GroupedLinear(num_groups=1) for dense MLP to trigger the
ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fusion on SM100+ with MXFP8 recipe.
Requires ``use_te_op_fuser=True`` and SwiGLU activation.
"""

moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/models/test_hybrid_moe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"deallocate_pipeline_outputs": True,
"defer_embedding_wgrad_compute": False,
"delay_wgrad_compute": False,
"dense_grouped_gemm": False,
"overlap_dispatch_backward_with_experts_wgrad": False,
"deterministic_mode": False,
"disable_bf16_reduced_precision_matmul": False,
Expand Down
87 changes: 87 additions & 0 deletions tests/unit_tests/transformer/test_te_fused_dense_mlp_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch.nn.functional as F

from megatron.core.extensions.transformer_engine import (
HAVE_TE,
TEFusedDenseMLP,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version
from tests.unit_tests.test_utilities import Utils

_SKIP_REASON = "TEFusedDenseMLP requires Transformer Engine >= 2.14.0"
_SKIP = not HAVE_TE or not is_te_min_version("2.14.0")


def _make_submodules():
return MLPSubmodules(linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear)


def _make_config(**overrides):
defaults = dict(
num_layers=1,
hidden_size=64,
num_attention_heads=4,
activation_func=F.silu,
gated_linear_unit=True,
add_bias_linear=False,
use_cpu_initialization=True,
)
defaults.update(overrides)
return TransformerConfig(**defaults)


@pytest.mark.skipif(_SKIP, reason=_SKIP_REASON)
class TestTEFusedDenseMLPSpec:

def setup_method(self, method):
Utils.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(123)

def teardown_method(self, method):
Utils.destroy_model_parallel()

def test_instantiation(self):
config = _make_config()
mlp = TEFusedDenseMLP(config, _make_submodules())
assert isinstance(mlp, TEFusedDenseMLP)

def test_wrong_activation_raises(self):
config = _make_config(activation_func=F.gelu, gated_linear_unit=False)
with pytest.raises(ValueError, match="SwiGLU activation"):
TEFusedDenseMLP(config, _make_submodules())

def test_gated_linear_unit_false_raises(self):
config = _make_config(gated_linear_unit=False)
with pytest.raises(ValueError, match="SwiGLU activation"):
TEFusedDenseMLP(config, _make_submodules())

def test_add_bias_linear_raises(self):
config = _make_config(add_bias_linear=True)
with pytest.raises(ValueError, match="add_bias_linear"):
TEFusedDenseMLP(config, _make_submodules())

def test_norm_seq_not_registered_as_submodule(self):
# _norm_seq must be stored in a tuple (not directly as nn.Module) to avoid
# PyTorch registering it as a submodule, which would duplicate norm weights
# in state_dict/parameters. Verify it starts as None and is never a bare Module.
import torch.nn as nn

config = _make_config()
mlp = TEFusedDenseMLP(config, _make_submodules())
assert mlp._norm_seq is None
assert '_norm_seq' not in dict(mlp.named_children())

# Simulate what _make_fused_impl does and confirm the tuple-wrap holds.
import transformer_engine.pytorch.ops as te_ops

fake_seq = te_ops.Sequential()
mlp._norm_seq = (fake_seq,)
assert not isinstance(mlp._norm_seq, nn.Module)
assert '_norm_seq' not in dict(mlp.named_children())
Loading