Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e48de2a
copy base classes
maanug-nv Mar 30, 2026
f561889
copy vocab utils
maanug-nv Mar 30, 2026
b03072d
copy distributed wrapping fns
maanug-nv Mar 30, 2026
ea9f60a
copy mamba cfg and builder
maanug-nv Mar 30, 2026
e7dc746
fix import cycle
maanug-nv Mar 30, 2026
e564315
copy base cfg+builder unit tests
maanug-nv Apr 2, 2026
b08c7e6
copy dist utils unit tests
maanug-nv Apr 2, 2026
ba1d3aa
copy mamba cfg+builder unit tests
maanug-nv Apr 2, 2026
448a7ee
update copyright year
maanug-nv Apr 30, 2026
0a25a81
rename to hybrid
maanug-nv Apr 30, 2026
0a3da47
move to avoid import cycle
maanug-nv Apr 30, 2026
63c0798
match inference spec in build_model with hybrid builders
maanug-nv Apr 30, 2026
eedc90e
add helper to build mamba cfg from args
maanug-nv Apr 30, 2026
19f903c
add model cfg to container in pretrain_hybrid
maanug-nv Apr 30, 2026
3d2c86b
refactor to include torch fsdp config
maanug-nv Apr 30, 2026
a33d292
refactor bucket size assertions
maanug-nv Apr 30, 2026
11b54b3
mirror last 2 commits in dist utils
maanug-nv Apr 30, 2026
21cd884
mirror ddp param layout refactor (#3812) in dist utils
maanug-nv Apr 30, 2026
042b98e
update tests
maanug-nv May 4, 2026
8edc29f
formatting
maanug-nv May 6, 2026
996a9b9
fix import
maanug-nv May 6, 2026
d85c831
re-enable serializable checks
maanug-nv May 6, 2026
025e391
fix headers
maanug-nv May 7, 2026
bfff4cc
update docstring
maanug-nv May 7, 2026
8c8d707
formatting
maanug-nv May 7, 2026
ecdb6ee
handle default spec in builder
maanug-nv May 7, 2026
2ddd5a1
defer test to future PR
maanug-nv May 7, 2026
9b72a53
remove generation config
maanug-nv May 11, 2026
ce22fb1
abstractify
maanug-nv May 11, 2026
0ffb267
update tests
maanug-nv May 11, 2026
a7a8402
docstring cleanup
maanug-nv May 11, 2026
b7c6cc4
Merge branch 'main' into migrate-mamba-builder
maanug-nv May 11, 2026
43bb38b
sync with legacy code removal
maanug-nv May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ class DistributedDataParallelConfig:
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""

num_buckets: Optional[int] = None
"""Number of buckets for data-parallel communication. Should only specify one of
`bucket_size` and `num_buckets`. If `num_buckets` is specified, `bucket_size`
will be determined at runtime."""

pad_buckets_for_high_nccl_busbw: bool = False
"""If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
Expand Down Expand Up @@ -225,3 +230,7 @@ def __post_init__(self):
"Only need to explicitly specify param_name patterns for FP32 local accumulation "
"if .main_grads aren't already in FP32"
)

if self.num_buckets is not None:
assert self.bucket_size is None, "Cannot specify both num_buckets and bucket_size"
assert self.num_buckets > 0, "num_buckets must be greater than 0"
163 changes: 162 additions & 1 deletion megatron/training/argument_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
import ast
import enum
from dataclasses import Field, fields
import warnings
import torch.nn.functional as F
import torch

from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.spec_utils import import_module

from megatron.training.config import (
DistributedInitConfig,
Expand All @@ -23,6 +29,7 @@
StragglerDetectionConfig,
RerunStateMachineConfig, CheckpointConfig, ProfilingConfig
)
from megatron.training.models.hybrid import HybridModelConfig
# TODO: support arg renames

class TypeInferenceError(Exception):
Expand Down Expand Up @@ -261,6 +268,108 @@ def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]:
return field_docstrings


def core_transformer_config_from_args(args, config_class=None):
    """Build a core transformer config object from parsed command-line arguments.

    Every dataclass field of the selected config class that has a same-named
    attribute on ``args`` is copied over first. The explicit assignments that
    follow then override or add the fields whose argument name or value differs
    from the config field.

    Args:
        args: Parsed argument namespace.
        config_class: Config dataclass to instantiate. Defaults to
            ``TransformerConfig``. It is replaced by ``MLATransformerConfig``
            when ``args.multi_latent_attention`` is set, and by
            ``HeterogeneousTransformerConfig`` when
            ``args.heterogeneous_layers_config_path`` is given.

    Returns:
        An instance of the selected config class.

    Raises:
        AssertionError: If multi latent attention is combined with heterogeneous
            layers, if ``squared_relu`` or ``quick_geglu`` is combined with
            ``swiglu``, if ``rope_type`` is not ``'rope'`` without multi latent
            attention, or if ``te_precision_config_file`` is given while a
            quantization recipe has already been configured.
    """
    # Imported inside the function rather than at module level.
    from megatron.core.activations import squared_relu
    from megatron.core.fusions.fused_bias_geglu import quick_gelu
    from megatron.core.transformer import MLATransformerConfig
    from megatron.core.transformer.heterogeneous.heterogeneous_config import (
        HeterogeneousTransformerConfig,
    )
    from megatron.core.quantization.utils import (
        kitchen_quantization_recipe_config,
        load_quantization_recipe,
    )

    # Config class.
    config_class = config_class or TransformerConfig

    if args.multi_latent_attention:
        config_class = MLATransformerConfig

    if args.heterogeneous_layers_config_path is not None:
        assert not args.multi_latent_attention, "Multi latent attention with heterogeneous layers is not supported."
        config_class = HeterogeneousTransformerConfig

    # Translate args to core transformer configuration
    kw_args = {}
    for f in dataclasses.fields(config_class):
        if hasattr(args, f.name):
            kw_args[f.name] = getattr(args, f.name)

    # Fields whose argument is named differently or is inverted.
    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
    kw_args['deallocate_pipeline_outputs'] = True
    kw_args['pipeline_dtype'] = args.params_dtype
    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
    kw_args['num_moe_experts'] = args.num_experts
    kw_args['rotary_interleaved'] = args.rotary_interleaved
    kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers
    kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers
    kw_args['fp8_param'] = args.fp8_param_gather
    kw_args['fp4_param'] = args.fp4_param_gather

    # Activation function and the matching bias-fusion flag.
    if args.swiglu:
        kw_args['activation_func'] = F.silu
        kw_args['gated_linear_unit'] = True
        kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion
    else:
        kw_args['bias_activation_fusion'] = args.bias_gelu_fusion
    if args.squared_relu:
        assert not args.swiglu
        kw_args['activation_func'] = squared_relu
    elif args.quick_geglu:
        assert not args.swiglu
        kw_args['gated_linear_unit'] = True
        kw_args['activation_func'] = quick_gelu

    if args.init_method_xavier_uniform:
        kw_args['init_method'] = torch.nn.init.xavier_uniform_
        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_

    # Query groups are only forwarded when group-query attention is enabled.
    if args.group_query_attention:
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None
    kw_args['config_logger_dir'] = args.config_logger_dir

    if args.rope_type is None:
        # Pop 'rope_type' to let the config class use the default value.
        kw_args.pop('rope_type', None)
    else:
        assert (args.multi_latent_attention or args.rope_type == 'rope'), (
            f'Common attention only support rope_type="rope", but got {args.rope_type}.'
        )

    # A single-element cp_comm_type list is unwrapped to its only element.
    if len(args.cp_comm_type) == 1:
        kw_args['cp_comm_type'] = args.cp_comm_type[0]
    if args.hybrid_layer_pattern is not None:
        kw_args['is_hybrid_model'] = True
        from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols
        if Symbols.DS_ATTENTION in args.hybrid_layer_pattern:
            kw_args['experimental_attention_variant'] = 'dsa'

    kw_args['inference_sampling_seed'] = args.seed

    # handle quantization config
    # NOTE: Kitchen arguments are only added to the namespace when
    # Kitchen library is available.
    if hasattr(args, "kitchen_config_file") and args.kitchen_config_file is not None:
        kw_args['use_kitchen'] = True
        kw_args['quant_recipe'] = load_quantization_recipe(args.kitchen_config_file)
    elif hasattr(args, 'kitchen_recipe_number') and args.kitchen_recipe_number is not None:
        kw_args['use_kitchen'] = True
        kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)

    kw_args['moe_latent_size'] = args.moe_latent_size

    if args.te_precision_config_file:
        assert 'quant_recipe' not in kw_args, "Quantization recipe already configured."
        # TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility
        kw_args['quant_recipe'] = load_quantization_recipe(args.te_precision_config_file)

    if hasattr(args, "use_kitchen_attention"):
        kw_args['use_kitchen_attention'] = args.use_kitchen_attention
    if hasattr(args, "kitchen_attention_backend"):
        kw_args['kitchen_attention_backend'] = args.kitchen_attention_backend

    # Return config.
    return config_class(**kw_args)


def _default_config_from_args(cls: type, args: Namespace, return_instance: bool = True) -> Any:
"""Create a config dataclass from the appropriate values in the `args` Namespace.

Expand All @@ -277,10 +386,61 @@ def _default_config_from_args(cls: type, args: Namespace, return_instance: bool
else:
return kwargs

def pretrain_cfg_container_from_args(args: Namespace) -> PretrainConfigContainer:

def hybrid_config_from_args(
    args: Namespace, config: TransformerConfig | None = None
) -> HybridModelConfig:
    """Create a HybridModelConfig from the appropriate values in the `args` Namespace.

    Args:
        args: Parsed argument namespace.
        config: Optional pre-built transformer config. When ``None``, one is
            built from ``args`` with ``core_transformer_config_from_args``.

    Returns:
        A ``HybridModelConfig`` populated from ``args`` and the transformer config.

    Raises:
        AssertionError: If ``args.use_legacy_models`` is not ``False``, if the
            transformer config uses the ``"inference_optimized"`` implementation
            with ``inference_fuse_tp_communication`` enabled, or if neither
            ``args.padded_vocab_size`` nor ``args.vocab_size`` is set.
    """

    assert args.use_legacy_models is False, "Hybrid model only supported in Mcore!"

    transformer_cfg = core_transformer_config_from_args(args) if config is None else config
    kwargs = {"transformer": transformer_cfg}

    if transformer_cfg.transformer_impl == "inference_optimized":
        # No hybrid_stack_spec is set on this path; args.spec is not consulted.
        assert (
            not transformer_cfg.inference_fuse_tp_communication
        ), "inference_fuse_tp_communication is not supported for HybridModel"
    elif args.spec is not None:
        kwargs["hybrid_stack_spec"] = import_module(args.spec)

    kwargs["fp16_lm_cross_entropy"] = args.fp16_lm_cross_entropy
    kwargs["hybrid_layer_pattern"] = args.hybrid_layer_pattern
    kwargs["position_embedding_type"] = args.position_embedding_type
    kwargs["rotary_percent"] = args.rotary_percent
    kwargs["rotary_base"] = args.rotary_base
    kwargs["make_vocab_size_divisible_by"] = args.make_vocab_size_divisible_by

    # These config fields are named differently from their arguments.
    kwargs["seq_len_interpolation_factor"] = args.rotary_seq_len_interpolation_factor
    kwargs["seq_length"] = args.max_position_embeddings
    kwargs["share_embeddings_and_output_weights"] = not args.untie_embeddings_and_output_weights

    if args.padded_vocab_size is not None:
        # Vocab size is passed as given; should_pad_vocab is left at its default.
        kwargs["vocab_size"] = args.padded_vocab_size
    else:
        # Megatron-Bridge uses an explicit setting "should_pad_vocab" so that
        # when converting model configs from HF, we can set a vocab size and disable padding.
        assert args.vocab_size is not None, "Either --padded-vocab-size or --vocab-size must be specified."
        kwargs["vocab_size"] = args.vocab_size
        kwargs["should_pad_vocab"] = True

    return HybridModelConfig(**kwargs)


def pretrain_cfg_container_from_args(args: Namespace, model_cfg=None) -> PretrainConfigContainer:
"""Build a PretrainConfigContainer from the argparse arguments."""
from megatron.training.training import get_megatron_ddp_config, get_megatron_optimizer_config

if model_cfg is None:
msg = """
It is recommended to use a ModelConfig (e.g. megatron.training.models.HybridModelConfig) instead
of a model builder/model provider function pointer.
"""
warnings.warn(msg)

ckpt_kwargs = _default_config_from_args(CheckpointConfig, args, return_instance=False)
ckpt_kwargs["save_optim"] = not args.no_save_optim
ckpt_kwargs["save_rng"] = not args.no_save_rng
Expand All @@ -301,6 +461,7 @@ def pretrain_cfg_container_from_args(args: Namespace) -> PretrainConfigContainer
cfg = PretrainConfigContainer(
train=_default_config_from_args(TrainingConfig, args),
validation=_default_config_from_args(ValidationConfig, args),
model=model_cfg,
optimizer=optim_cfg,
scheduler=_default_config_from_args(SchedulerConfig, args),
ddp=ddp_config,
Expand Down
94 changes: 1 addition & 93 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
load_quantization_recipe,
)

from megatron.training.argument_utils import ArgumentGroupFactory
from megatron.training.argument_utils import ArgumentGroupFactory, core_transformer_config_from_args

def add_megatron_arguments(parser: argparse.ArgumentParser):
"""Add Megatron-LM arguments to the given parser."""
Expand Down Expand Up @@ -1723,98 +1723,6 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def core_transformer_config_from_args(args, config_class=None):
    """Build a core transformer config object from parsed command-line arguments.

    Every dataclass field of the selected config class that has a same-named
    attribute on ``args`` is copied over first. The explicit assignments that
    follow then override or add the fields whose argument name or value differs
    from the config field.

    Args:
        args: Parsed argument namespace.
        config_class: Config dataclass to instantiate. Defaults to
            ``TransformerConfig``. It is replaced by ``MLATransformerConfig``
            when ``args.multi_latent_attention`` is set, and by
            ``HeterogeneousTransformerConfig`` when
            ``args.heterogeneous_layers_config_path`` is given.

    Returns:
        An instance of the selected config class.

    Raises:
        AssertionError: If multi latent attention is combined with heterogeneous
            layers, if ``squared_relu`` or ``quick_geglu`` is combined with
            ``swiglu``, if ``rope_type`` is not ``'rope'`` without multi latent
            attention, or if ``te_precision_config_file`` is given while a
            quantization recipe has already been configured.
    """

    # Config class.
    config_class = config_class or TransformerConfig

    if args.multi_latent_attention:
        config_class = MLATransformerConfig

    if args.heterogeneous_layers_config_path is not None:
        assert not args.multi_latent_attention, "Multi latent attention with heterogeneous layers is not supported."
        config_class = HeterogeneousTransformerConfig

    # Translate args to core transformer configuration
    kw_args = {}
    for f in dataclasses.fields(config_class):
        if hasattr(args, f.name):
            kw_args[f.name] = getattr(args, f.name)

    # Fields whose argument is named differently or is inverted.
    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
    kw_args['deallocate_pipeline_outputs'] = True
    kw_args['pipeline_dtype'] = args.params_dtype
    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
    kw_args['num_moe_experts'] = args.num_experts
    kw_args['rotary_interleaved'] = args.rotary_interleaved
    kw_args['num_layers_in_first_pipeline_stage'] = args.decoder_first_pipeline_num_layers
    kw_args['num_layers_in_last_pipeline_stage'] = args.decoder_last_pipeline_num_layers
    kw_args['fp8_param'] = args.fp8_param_gather
    kw_args['fp4_param'] = args.fp4_param_gather

    # Activation function and the matching bias-fusion flag.
    if args.swiglu:
        kw_args['activation_func'] = F.silu
        kw_args['gated_linear_unit'] = True
        kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion
    else:
        kw_args['bias_activation_fusion'] = args.bias_gelu_fusion
    if args.squared_relu:
        assert not args.swiglu
        kw_args['activation_func'] = squared_relu
    elif args.quick_geglu:
        assert not args.swiglu
        kw_args['gated_linear_unit'] = True
        kw_args['activation_func'] = quick_gelu

    if args.init_method_xavier_uniform:
        kw_args['init_method'] = torch.nn.init.xavier_uniform_
        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_

    # Query groups are only forwarded when group-query attention is enabled.
    if args.group_query_attention:
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None
    kw_args['config_logger_dir'] = args.config_logger_dir

    if args.rope_type is None:
        # Pop 'rope_type' to let the config class use the default value.
        kw_args.pop('rope_type', None)
    else:
        assert (args.multi_latent_attention or args.rope_type == 'rope'), (
            f'Common attention only support rope_type="rope", but got {args.rope_type}.'
        )

    # A single-element cp_comm_type list is unwrapped to its only element.
    if len(args.cp_comm_type) == 1:
        kw_args['cp_comm_type'] = args.cp_comm_type[0]
    if args.hybrid_layer_pattern is not None:
        kw_args['is_hybrid_model'] = True
        from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols
        if Symbols.DS_ATTENTION in args.hybrid_layer_pattern:
            kw_args['experimental_attention_variant'] = 'dsa'

    kw_args['inference_sampling_seed'] = args.seed

    # handle quantization config
    # NOTE: Kitchen arguments are only added to the namespace when
    # Kitchen library is available.
    if hasattr(args, "kitchen_config_file") and args.kitchen_config_file is not None:
        kw_args['use_kitchen'] = True
        kw_args['quant_recipe'] = load_quantization_recipe(args.kitchen_config_file)
    elif hasattr(args, 'kitchen_recipe_number') and args.kitchen_recipe_number is not None:
        kw_args['use_kitchen'] = True
        kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)

    kw_args['moe_latent_size'] = args.moe_latent_size

    if args.te_precision_config_file:
        assert 'quant_recipe' not in kw_args, "Quantization recipe already configured."
        # TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility
        kw_args['quant_recipe'] = load_quantization_recipe(args.te_precision_config_file)

    if hasattr(args, "use_kitchen_attention"):
        kw_args['use_kitchen_attention'] = args.use_kitchen_attention
    if hasattr(args, "kitchen_attention_backend"):
        kw_args['kitchen_attention_backend'] = args.kitchen_attention_backend

    # Return config.
    return config_class(**kw_args)


def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')

Expand Down
3 changes: 2 additions & 1 deletion megatron/training/config/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from megatron.training.config.utils import sanitize_dataclass_config
from megatron.training.config.instantiate_utils import InstantiationMode, instantiate
from megatron.training.config.yaml_utils import safe_yaml_representers
from megatron.training.models.hybrid import HybridModelConfig

T = TypeVar("T", bound="ConfigContainerBase")

Expand Down Expand Up @@ -217,7 +218,7 @@ class PretrainConfigContainer(ConfigContainerBase):

train: TrainingConfig
validation: ValidationConfig = field(default_factory=ValidationConfig)
# model: GPTModelConfig | MambaModelConfig # TODO (@maanug): add support
model: HybridModelConfig # TODO (@maanug): add support for GPTModelConfig

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: This field is typed as HybridModelConfig with no default, but pretrain_cfg_container_from_args defaults model_cfg=None and five existing callers (pretrain_gpt.py, pretrain_bert.py, pretrain_t5.py, pretrain_vlm.py, train_rl.py) call it without a model_cfg, so None is passed here. This will break type checkers and cause a runtime AttributeError if any downstream code accesses attributes on cfg.model.

Suggested change
model: HybridModelConfig # TODO (@maanug): add support for GPTModelConfig
model: HybridModelConfig | None = None # TODO (@maanug): add support for GPTModelConfig

optimizer: OptimizerConfig
scheduler: SchedulerConfig
# dataset: GPTDatasetConfig # TODO (@maanug): add support
Expand Down
24 changes: 24 additions & 0 deletions megatron/training/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from megatron.training.models.base import ModelBuilder, ModelConfig, Serializable, compose_hooks
from megatron.training.models.dist_utils import (
build_virtual_pipeline_stages,
unimodal_build_distributed_models,
)
from megatron.training.models.hybrid import HybridModelBuilder, HybridModelConfig

# Aliases exposing the hybrid config/builder under the Mamba names.
# NOTE(review): presumably kept for backward compatibility after the rename
# to "hybrid" -- confirm before removing.
MambaModelConfig = HybridModelConfig
MambaModelBuilder = HybridModelBuilder

# Public API of the package: every name imported above plus the two aliases.
__all__ = [
"ModelBuilder",
"ModelConfig",
"Serializable",
"compose_hooks",
"build_virtual_pipeline_stages",
"unimodal_build_distributed_models",
"HybridModelConfig",
"HybridModelBuilder",
"MambaModelConfig",
"MambaModelBuilder",
]
Loading
Loading